Add relation-aware graph transformer signals#674

Draft
yliu2-sc wants to merge 2 commits into main from yliu2/gt-rel

Conversation

@yliu2-sc (Collaborator) commented Jun 15, 2026
Scope of work done

This PR adds opt-in relation-aware attention logits to the GiGL Graph Transformer without enabling the graph-edge hard attention mask path.

  • Adds relation_attention_mode="edge_type_bilinear" to GraphTransformerEncoderLayer.
  • Adds zero-initialized per-relation bilinear attention matrices shaped (num_relations, num_heads, head_dim, head_dim).
  • Represents relation edges as sparse (batch_idx, query_pos, key_pos, relation_idx) coordinates.
  • Maps directed graph edges source -> target into attention coordinates as query=target, key=source.
  • Wires relation coordinate construction through heterodata_to_graph_transformer_input.
  • Derives relation order in GraphTransformerEncoder from sorted edge_type_to_feat_dim_map.
  • Adds transform and encoder unit coverage for relation ordering/direction, zero-init equivalence, invalid relation IDs, indexed-pair logit updates, padding exclusion, and existing negative attention masks.
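
The relation ordering and direction conventions listed above can be illustrated with a minimal plain-Python sketch. It is not the GiGL implementation: `build_relation_coords`, the `(batch_idx, source_pos, target_pos)` edge layout, and the example edge types are hypothetical, and it assumes that "sorted `edge_type_to_feat_dim_map`" means sorting the map's keys.

```python
from typing import Dict, List, Tuple

EdgeType = Tuple[str, str, str]  # (source node type, relation name, target node type)


def build_relation_coords(
    edge_type_to_feat_dim_map: Dict[EdgeType, int],
    edges_by_type: Dict[EdgeType, List[Tuple[int, int, int]]],
) -> Tuple[List[EdgeType], List[Tuple[int, int, int, int]]]:
    """Return the relation order and sparse (batch_idx, query_pos, key_pos, relation_idx) rows.

    Each edge is given as (batch_idx, source_pos, target_pos); this layout is an
    assumption made for illustration only.
    """
    # Sorting the edge types gives every edge type a deterministic relation index.
    relation_order = sorted(edge_type_to_feat_dim_map)
    relation_idx_of = {edge_type: idx for idx, edge_type in enumerate(relation_order)}

    coords = []
    for edge_type, edges in edges_by_type.items():
        if edge_type not in relation_idx_of:
            raise ValueError(f"Unknown edge type: {edge_type}")
        relation_idx = relation_idx_of[edge_type]
        for batch_idx, source_pos, target_pos in edges:
            # A directed edge source -> target becomes query=target, key=source.
            coords.append((batch_idx, target_pos, source_pos, relation_idx))
    return relation_order, coords


# Hypothetical edge types and positions.
feat_dims = {("user", "follows", "user"): 0, ("user", "buys", "item"): 4}
edges = {
    ("user", "follows", "user"): [(0, 1, 2)],
    ("user", "buys", "item"): [(0, 1, 3), (1, 0, 2)],
}
relation_order, coords = build_relation_coords(feat_dims, edges)
```

Here `("user", "buys", "item")` sorts before `("user", "follows", "user")`, so it receives relation index 0 regardless of the insertion order of the map.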

Explicitly out of scope:

  • Graph-edge hard attention masking.
  • Relation value residual gates.
  • Sparse edge-attribute attention bias.
  • Sparse pairwise nonmissing structural-bias indices.

Implementation notes

Default behavior remains unchanged unless relation_attention_mode="edge_type_bilinear" is configured.

Relation-aware logits initialize to zero, preserving baseline outputs at initialization.
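
To show why zero initialization preserves baseline outputs, here is a rough single-head, single-example sketch in plain Python rather than PyTorch. The function name `relation_aware_logits` and the choice to add the bilinear term unscaled on top of scaled dot-product logits are assumptions for illustration, not details confirmed by this PR.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def relation_aware_logits(
    queries: Sequence[Vector],
    keys: Sequence[Vector],
    relation_mats: Sequence[Matrix],
    relation_coords: Sequence[Tuple[int, int, int]],
) -> List[List[float]]:
    """Scaled dot-product logits plus a bilinear term at (query_pos, key_pos, relation_idx)."""
    head_dim = len(queries[0])
    scale = 1.0 / math.sqrt(head_dim)
    # Baseline logits for every (query, key) pair.
    logits = [[scale * dot(q, k) for k in keys] for q in queries]
    for query_pos, key_pos, relation_idx in relation_coords:
        if not 0 <= relation_idx < len(relation_mats):
            raise ValueError(f"Invalid relation index: {relation_idx}")
        w = relation_mats[relation_idx]
        q, k = queries[query_pos], keys[key_pos]
        # Bilinear term q^T W_r k, added only at the indexed pair.
        logits[query_pos][key_pos] += sum(q[a] * dot(w[a], k) for a in range(head_dim))
    return logits


queries = [[1.0, 0.0], [0.0, 2.0]]
keys = [[1.0, 1.0], [0.5, -1.0]]
zero_mats = [[[0.0, 0.0], [0.0, 0.0]]]      # one relation, zero-initialized
identity_mats = [[[1.0, 0.0], [0.0, 1.0]]]  # stand-in for a trained matrix

baseline = relation_aware_logits(queries, keys, zero_mats, [])
at_init = relation_aware_logits(queries, keys, zero_mats, [(0, 1, 0)])
updated = relation_aware_logits(queries, keys, identity_mats, [(0, 1, 0)])
```

With zero matrices, `at_init` matches `baseline` exactly. With a nonzero matrix, only the indexed pair `(query_pos=0, key_pos=1)` changes.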

Sparse relation coordinates are built before relation identity is lost in to_homogeneous().

The main attention path still uses PyTorch SDPA. This PR only adds a sparse relation logit bias that is applied before SDPA.

Where is the documentation for this feature?: N/A for this draft. I can add docs/changelog notes once we settle the final public interface names.

Did you add automated tests or write a test plan?

Added unit coverage in:

  • tests/unit/nn/graph_transformer_test.py
  • tests/unit/transforms/graph_transformer_test.py

Local checks run:

  • python3 -m py_compile gigl/nn/graph_transformer.py gigl/transforms/graph_transformer.py tests/unit/nn/graph_transformer_test.py tests/unit/transforms/graph_transformer_test.py
  • .venv/bin/ruff check --fix --config pyproject.toml gigl/nn/graph_transformer.py gigl/transforms/graph_transformer.py tests/unit/nn/graph_transformer_test.py tests/unit/transforms/graph_transformer_test.py
  • .venv/bin/ruff format --config pyproject.toml gigl/nn/graph_transformer.py gigl/transforms/graph_transformer.py tests/unit/nn/graph_transformer_test.py tests/unit/transforms/graph_transformer_test.py
  • git diff --check
  • Focused smoke script covering bilinear sparse bias, invalid relation IDs, transform relation indices, and encoder zero-init parity.

Comment thread gigl/nn/graph_transformer.py Outdated
self-attention path. ``"edge_type_bilinear"`` adds a learned
per-edge-type bilinear term for sampled directed graph edges. This
changes attention weights, not value/message content.
relation_value_mode: Optional relation-aware value augmentation strategy.

Collaborator

Discussed offline that we can leave relation_value_mode out of this PR.

@mkolodner-sc mkolodner-sc left a comment

Collaborator

stamp

3 participants