1- // Copyright (c) 2022 Google LLC
1+ // Copyright (c) 2025 Epic Games, Inc.
22//
33// Licensed under the Apache License, Version 2.0 (the "License");
44// you may not use this file except in compliance with the License.
@@ -30,11 +30,20 @@ constexpr uint32_t kOpEntryPointInOperandInterface = 3;
3030constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0 ;
3131constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0 ;
3232constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1 ;
33+ constexpr uint32_t kOpTypeVectorComponentCountInOperandIndex = 1 ;
3334constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1 ;
3435constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0 ;
3536constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1 ;
3637constexpr uint32_t kOpConstantValueInOperandIndex = 0 ;
3738
39+ // Get the component count of the OpTypeVector |vector_type|.
40+ uint32_t GetVectorComponentCount (Instruction* vector_type) {
41+ assert (vector_type->opcode () == spv::Op::OpTypeVector);
42+ uint32_t component_count =
43+ vector_type->GetSingleWordInOperand (kOpTypeVectorComponentCountInOperandIndex );
44+ return component_count;
45+ }
46+
3847// Get the length of the OpTypeArray |array_type|.
3948uint32_t GetArrayLength (analysis::DefUseManager* def_use_mgr,
4049 Instruction* array_type) {
@@ -223,6 +232,8 @@ Pass::Status AdvancedInterfaceVariableScalarReplacement::ProcessEntryPoint(
223232
224233 ReplaceInEntryPoint (&entry_point, replaced_interface_vars, scalar_vars);
225234
235+ context ()->InvalidateAnalysesExceptFor (IRContext::Analysis::kAnalysisNone );
236+
226237 return status;
227238}
228239
@@ -330,19 +341,46 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceInterfaceVariable(
330341 // We are going to replace the access chain with either direct usage of the
331342 // replacement scalar variable, or a set of composite loads/stores.
332343
333- const Replacement* target =
344+ LookupResult result =
334345 LookupReplacement (access_chain, &replacement, var.extra_array_length );
335- if (!target ) {
346+ if (!result. replacement ) {
336347 // Error has been already logged by |LookupReplacement|.
337348 return false ;
338349 }
350+ const Replacement* target = result.replacement ;
339351
340352 if (!target->HasChildren () && var.extra_array_length == 0 ) {
341- // Replace with a direct use of the scalar variable.
342353 auto scalar = target->GetScalarVariable ();
343354 assert (scalar);
344- context ()->ReplaceAllUsesWith (access_chain->result_id (),
345- scalar->result_id ());
355+
356+ uint32_t replacement = 0 ;
357+ if (result.index >= 0 ) {
358+ // Our scalar is a vector and access chain in question targets a
359+ // specific component denoted by result.index.
360+ assert (target->GetVectorComponentCount () > 0 );
361+ // Replace with an access chain into a direct use of the scalar variable.
362+ uint32_t indirection_id = TakeNextId ();
363+ if (indirection_id == 0 ) {
364+ return false ;
365+ }
366+
367+ uint32_t vector_component_type_id = context ()->get_def_use_mgr ()->GetDef (target->GetTypeId ())->GetSingleWordInOperand (0 );
368+
369+ uint32_t index_id = context ()->get_constant_mgr ()->GetUIntConstId (result.index );
370+ Operand index_operand = {SPV_OPERAND_TYPE_ID, {index_id}};
371+ std::unique_ptr<Instruction> vector_access_chain =
372+ CreateAccessChain (context (), indirection_id, scalar,
373+ vector_component_type_id, index_operand);
374+ replacement = vector_access_chain->result_id ();
375+
376+ auto inst = access_chain->InsertBefore (std::move (vector_access_chain));
377+ inst->UpdateDebugInfoFrom (access_chain);
378+ get_def_use_mgr ()->AnalyzeInstDef (inst);
379+ } else {
380+ // Replace with a direct use of the scalar variable.
381+ replacement = scalar->result_id ();
382+ }
383+ context ()->ReplaceAllUsesWith (access_chain->result_id (), replacement);
346384 } else {
347385 // The current access chain's target is a composite, meaning that there
348386 // are other instructions using the pointer. We need to convert those to
@@ -732,7 +770,7 @@ bool AdvancedInterfaceVariableScalarReplacement::ReplaceStore(
732770 return true ;
733771}
734772
735- const AdvancedInterfaceVariableScalarReplacement::Replacement*
773+ AdvancedInterfaceVariableScalarReplacement::LookupResult
736774AdvancedInterfaceVariableScalarReplacement::LookupReplacement (
737775 Instruction* access_chain, const Replacement* root,
738776 uint32_t extra_array_length) {
@@ -744,37 +782,59 @@ AdvancedInterfaceVariableScalarReplacement::LookupReplacement(
744782 // array, hence we skip it when looking-up the rest.
745783 uint32_t start_index = extra_array_length == 0 ? 1 : 2 ;
746784
785+ uint32_t num_indices = access_chain->NumInOperands ();
786+
747787 // Finds the target replacement, which might be a scalar or nested
748788 // composite.
749- for (uint32_t i = start_index; i < access_chain-> NumInOperands () ; ++i) {
789+ for (uint32_t i = start_index; i < num_indices ; ++i) {
750790 uint32_t index_id = access_chain->GetSingleWordInOperand (i);
751791
752792 const analysis::Constant* index_constant =
753793 const_mgr->FindDeclaredConstant (index_id);
754794 if (!index_constant) {
755795 context ()->EmitErrorMessage (
756796 " Variable cannot be replaced: index is not constant" , access_chain);
757- return nullptr ;
797+ return {};
798+ }
799+
800+ // OpAccessChain treats indices as signed.
801+ int64_t index_value = index_constant->GetSignExtendedValue ();
802+
803+ // Very last index can target the vector type, which we
804+ // have as a scalar.
805+ if (i == num_indices - 1 ) {
806+ if (root->GetScalarVariable ()) {
807+ if (index_value < 0 ||
808+ index_value >=
809+ static_cast <int64_t >(root->GetVectorComponentCount ())) {
810+ // Out of bounds access, this is illegal IR.
811+ // Notice that OpAccessChain indexing is 0-based, so we should also
812+ // reject index == size-of-array.
813+ context ()->EmitErrorMessage (
814+ " Variable cannot be replaced: invalid index" , access_chain);
815+ return {};
816+ }
817+ // Current root is our replacement scalar - a vector, in fact.
818+ return {root, index_value};
819+ }
758820 }
759821
760822 assert (root->HasChildren ());
761823 const auto & children = root->GetChildren ();
762824
763- // OpAccessChain treats indices as signed.
764- int64_t index_value = index_constant->GetSignExtendedValue ();
765825 if (index_value < 0 ||
766826 index_value >= static_cast <int64_t >(children.size ())) {
767827 // Out of bounds access, this is illegal IR.
768828 // Notice that OpAccessChain indexing is 0-based, so we should also
769829 // reject index == size-of-array.
770830 context ()->EmitErrorMessage (" Variable cannot be replaced: invalid index" ,
771831 access_chain);
772- return nullptr ;
832+ return {} ;
773833 }
774834
775835 root = &children[index_value];
776836 }
777- return root;
837+ return { root} ;
778838}
779839
780840AdvancedInterfaceVariableScalarReplacement::Replacement
@@ -863,7 +923,12 @@ AdvancedInterfaceVariableScalarReplacement::CreateReplacementVariables(
863923 std::unique_ptr<Instruction> variable = CreateVariable (
864924 type->result_id (), storage_class, var.def , var.extra_array_length );
865925
866- node->SetSingleScalarVariable (variable.get ());
926+ uint32_t vector_component_count = 0 ;
927+ if (opcode == spv::Op::OpTypeVector) {
928+ vector_component_count = GetVectorComponentCount (type);
929+ }
930+
931+ node->SetSingleScalarVariable (variable.get (), vector_component_count);
867932 scalar_vars->push_back (variable.get ());
868933
869934 uint32_t var_id = variable->result_id ();
0 commit comments