Conversation

@irexyc
Collaborator

@irexyc irexyc commented Sep 9, 2025

Thanks for your contribution; we appreciate it a lot. The following instructions will help make your pull request healthier and easier to review. If you do not understand some items, don't worry; just create the pull request and ask the maintainers for help.

Motivation

Please describe the motivation for this PR and the goal you want to achieve through it.

Modification

Please briefly describe the modifications made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward compatibility of downstream repositories?
If so, please describe how it breaks compatibility and how downstream projects should modify their code to stay compatible with this PR.

Use cases (Optional)

If this PR introduces a new feature, please list some use cases here and update the documentation accordingly.

Checklist

  1. Pre-commit or other linting tools are used to fix potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. If the modification depends on a newer version of downstream projects, this PR should be tested with all supported versions of those projects.
  4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

@lzhangzz lzhangzz self-requested a review September 23, 2025 05:46
@lvhan028 lvhan028 marked this pull request as ready for review October 30, 2025 13:38
@lvhan028 lvhan028 changed the title from "[WIP] support context parallel" to "support context parallel" Oct 30, 2025
@lvhan028 lvhan028 added the enhancement New feature or request label Oct 30, 2025
@lvhan028
Collaborator

lvhan028 commented Nov 3, 2025

This may resolve the build error on the Windows platform:

const int tp_rank_;
const DataType data_type_;
const bool debug_;
const bool is_driver_;
Collaborator

Consider renaming is_driver_ to something more specific. The current name is vague: what exactly is it driving or controlling?

tp_rank_(model->tp_rank_),
data_type_(data_type),
debug_(isDebug()),
is_driver_(param.attn_tp_rank == 0 && param.attn_cp_rank == 0),
Collaborator

With the current settings this is the same as tp_rank_ == 0, so is_driver_ is not needed.
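A minimal sketch of the suggested simplification, assuming tp_rank_ is 0 exactly when both attn_tp_rank and attn_cp_rank are 0 under the current rank layout:

// Hypothetical simplification: drop the extra flag and test the rank directly.
// is_driver_(param.attn_tp_rank == 0 && param.attn_cp_rank == 0)   // before
if (tp_rank_ == 0) {  // equivalent under the current setting
    // driver-only work, e.g. logging or host-side bookkeeping
}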


self._postprocess_config(tm_model.tm_config, engine_config)

print(yaml.safe_dump(self.config_dict))
Collaborator

Control this output with the log level instead of printing the config unconditionally.

// For context parallel we use symm_alloc_, and both the prefill and decode stages have a reduce step.
// Without context parallel we use the common allocator, and only the decode stage has a reduce step.
// Perhaps it would be more appropriate to put this buffer in the unified_attention_layer.
Allocator alloc = param_.attn_cp_size > 1 ? symm_alloc_ : core::Context::alloc(kDEVICE);
Collaborator

This creates a new allocator, which is not needed here. Use core::Context::device_alloc() to get the device allocator of the current context.
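A minimal sketch of the suggested fix, assuming core::Context::device_alloc() returns the device allocator already owned by the current context:

// Reuse the context's device allocator instead of constructing a new one;
// symm_alloc_ is still required when context parallelism is enabled.
Allocator alloc = param_.attn_cp_size > 1 ? symm_alloc_ : core::Context::device_alloc();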

if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
params.partial_M[index] = M;
params.partial_L[index] = L;
params.partial_ML[index * 2] = M;
Collaborator

Make partial_ML a pointer to float2 so that loads and stores can be vectorized.
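A minimal sketch of the vectorized store, assuming partial_ML is retyped as float2* and indexed with the same index as above:

// With float2* partial_ML, M and L are written in a single 8-byte
// transaction instead of two scalar stores.
if (qi_begin + qi < qi_end && ri == 0 && check_h(hi)) {
    params.partial_ML[index] = make_float2(M, L);  // was [index * 2] and [index * 2 + 1]
}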

@@ -0,0 +1,20 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "src/turbomind/models/llama/cp_utils.h"
Collaborator

move cp_utils.* to kernels/attention

}
}

int cp_quo, cp_rem;
Collaborator

use expressive names, e.g. local_ti and local_ti_rank
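A sketch of the suggested rename, assuming cp_quo and cp_rem hold the quotient and remainder of a tile index divided by the cp size (the surrounding computation is not shown in this hunk):

// Hypothetical rename for readability: the quotient is the tile index
// local to this cp rank, the remainder selects the owning cp rank.
const int local_ti      = ti / params.cp_size;  // was cp_quo
const int local_ti_rank = ti % params.cp_size;  // was cp_rem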

}

const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S);
const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size;
Collaborator

use cdiv
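For reference, the same rounding-up division written with the cdiv helper already used on the previous line:

const int max_cp_k_len = cdiv(params.max_k_len, params.cp_size);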

}();

const int tile_count = cdiv(std::min(params.max_k_len, params.window_size), Kernel::CTA_S);
const int max_cp_k_len = (params.max_k_len + params.cp_size - 1) / params.cp_size;
Collaborator

use cdiv

const int qi = offset.y / CTA_H;
const int ti = history_len;

int cp_quo, cp_rem;
Collaborator

use expressive names

});
}

const bool separate_reduce = need_separate_reduce(cta_map.split_count());
Collaborator

This code path can be removed.


Impl::Merge(frag_O, frag_M, frag_L, params.inv_sqrt_dh, storage);

if (params.sinks && iter_end == tile_count) {
Collaborator

The attention sink should be applied on cp rank 0 ONLY.
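A minimal sketch of the guard, assuming the rank is exposed as params.cp_rank (a hypothetical field name; only params.cp_size appears in the hunks above):

// Apply the attention sink on cp rank 0 only; otherwise each cp rank
// would add the sink contribution once more during the cp reduce.
if (params.cp_rank == 0 && params.sinks && iter_end == tile_count) {
    // ... existing sink handling ...
}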

@lvhan028 lvhan028 merged commit 2bc6529 into InternLM:main Nov 19, 2025
9 checks passed