File tree Expand file tree Collapse file tree 1 file changed +9
-8
lines changed Expand file tree Collapse file tree 1 file changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -756,17 +756,18 @@ void ShardingUtil::ReshardParameters(
756756 std::vector<runtime::ComputationClient::DataPtr> data_to_reshard;
757757 std::vector<torch_xla::OpSharding> shardings_to_reshard;
758758
759+ std::vector<int64_t > denormalized_tile_assignment;
760+ auto sharding_spec = (*tensors)[0 ]->sharding_spec ();
761+ if (sharding_spec) {
762+ denormalized_tile_assignment = sharding_spec->sharding .GetDenormalizedTileAssignment ();
763+ }
759764 for (const auto & sharding : xla_input_shardings) {
760- for (const auto & data : *parameters) {
761- runtime::ComputationClient::DataPtr handle =
762- std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data);
763- auto computation_client_ptr = runtime::GetComputationClient ();
764- torch_xla::OpSharding torch_xla_opsharding =
765- (*computation_client_ptr)->GetDataSharding (handle).value ();
766- std::vector<int64_t > denormalized_tile_assignment =
767- torch_xla_opsharding.GetDenormalizedTileAssignment ();
765+ if (denormalized_tile_assignment.size () > 0 ){
768766 input_shardings.emplace_back (sharding, denormalized_tile_assignment);
769767 }
768+ else {
769+ input_shardings.emplace_back (sharding);
770+ }
770771 }
771772
772773 for (int i = 0 ; i < input_shardings.size (); ++i) {
You can’t perform that action at this time.
0 commit comments