diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index fe4241ae110..8636e27a8af 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -350,17 +350,25 @@ std::string SGObject::get_description(std::string_view name) const void SGObject::build_gradient_parameter_dictionary(std::map>& dict) { - for (auto& param: self->filter(ParameterProperties::GRADIENT)) - dict[{param.first.name(), std::make_shared(param.second)}] = shared_from_this(); + for (auto& param: self->filter(ParameterProperties::GRADIENT)) + dict[{param.first.name(), std::make_shared(param.second)}] = shared_from_this(); - for (const auto& param: self->filter(ParameterProperties::HYPER)) + InterfaceVisitor iv; + for (const auto& param: self->filter(ParameterProperties::HYPER)) { - if (auto child = sgo_details::get_by_tag(shared_from_this(), param.first.name(), sgo_details::GetByName())) - child->build_gradient_parameter_dictionary(dict); - else if (auto child = get(param.first.name(), std::nothrow)) - child->build_gradient_parameter_dictionary(dict); - else - SG_DEBUG("Parameter {} is not a SGObject. Skipping...", param.first.name().c_str()) + try + { + param.second.get_value().visit_with(&iv); + if (iv.value) + { + iv.value->build_gradient_parameter_dictionary(dict); + iv.value.reset(); // need to reset it otherwise iv should be created after each iter + } + } + catch (std::logic_error& e) + { + SG_DEBUG("Parameter {} is not a SGObject. Skipping...", param.first.name().c_str()); + } } } @@ -714,21 +722,44 @@ void SGObject::init_auto_params() std::shared_ptr SGObject::get(std::string_view name, index_t index) const { - auto result = sgo_details::get_by_tag(shared_from_this(), name, sgo_details::GetByNameIndex(index)); - if (!result && has(name)) + BaseTag tag(name); + if (self->has(tag)) { - error( - "Cannot get array parameter {}::{}[{}] of type {} as object.", - get_name(), name.data(), index, - self->map[BaseTag(name)].get_value().type().c_str()); + try + { + InterfaceVisitor iv; + iv.index = index; + self->get(tag).get_value().visit_with(&iv); + return std::move(iv.value); + } + catch (...) { /* handle it below... */ } } - return result; + + error( + "Cannot get array parameter {}::{}[{}] of type {} as object.", + get_name(), name.data(), index, + self->map[BaseTag(name)].get_value().type().c_str()); + + return nullptr; } std::shared_ptr SGObject::get(std::string_view name, std::nothrow_t) const noexcept { - return sgo_details::get_by_tag(shared_from_this(), name, sgo_details::GetByName()); + BaseTag tag(name); + if (!self->has(tag)) + return nullptr; + + try + { + InterfaceVisitor iv; + self->get(tag).get_value().visit_with(&iv); + return std::move(iv.value); + } + catch(...) + { + return nullptr; + } } std::shared_ptr SGObject::get(std::string_view name) const noexcept(false) diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 575aac2bf39..dd08393eb8a 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -50,18 +50,6 @@ class ObservedValueTemplated; class Serializer; } -#ifndef SWIG -#ifndef DOXYGEN_SHOULD_SKIP_THIS - namespace sgo_details - { - template - bool dispatch_array_type( - const std::shared_ptr& obj, std::string_view name, - T2&& lambda); - } // namespace sgo_details -#endif // DOXYGEN_SHOULD_SKIP_THIS -#endif // SWIG - using stringToEnumMapType = std::unordered_map>; /******************************************************************************* @@ -130,6 +118,12 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name) */ class SGObject: public std::enable_shared_from_this { + struct InterfaceVisitor + { + index_t index; + std::shared_ptr value; + }; + public: /** Definition of observed subject */ typedef rxcpp::subjects::subject> SGSubject; @@ -270,7 +264,7 @@ class SGObject: public std::enable_shared_from_this * @param name name of the parameter * @return true if the parameter exists with the input name and type */ - template ::value>* = nullptr> + template >* = nullptr> bool has(std::string_view name) const noexcept(true) { BaseTag tag(name); @@ -280,7 +274,7 @@ class SGObject: public std::enable_shared_from_this return value.has_type(); } - template ::value>* = nullptr> + template >* = nullptr> bool has(std::string_view name) const noexcept(true) { BaseTag tag(name); @@ -373,7 +367,7 @@ class SGObject: public std::enable_shared_from_this * @param value value of the parameter */ template ::value>, + class X = typename std::enable_if_t>, class Z = void> #ifdef SWIG void put(const std::string& name, std::shared_ptr value) @@ -391,7 +385,7 @@ class SGObject: public std::enable_shared_from_this * @param value value of the parameter */ template ::value>> + class X = typename std::enable_if_t>> #ifdef SWIG void add(const std::string& name, std::shared_ptr value) #else @@ -419,26 +413,22 @@ class SGObject: public std::enable_shared_from_this * @return desired element */ template ::value>> + class X = typename std::enable_if_t>> std::shared_ptr get(std::string_view name, index_t index, std::nothrow_t) const { - std::shared_ptr result; - - auto get_lambda = [&index, &result](auto& array) { - result = array.at(index); - }; - if (sgo_details::dispatch_array_type(shared_from_this(), name, get_lambda)) + Tag>> tag(name); + if (has(tag)) { - ASSERT(result); - // guard against mixed types in the array - return result->as(); + InterfaceVisitor iv; + iv.index = index; + get(tag).get_value().visit_with(&iv); + return std::move(iv.value->as()); } - return nullptr; } template ::value>> + class X = typename std::enable_if_t>> std::shared_ptr get(std::string_view name, index_t index) const { auto result = this->get(name, index, std::nothrow); @@ -475,7 +465,7 @@ class SGObject: public std::enable_shared_from_this */ std::shared_ptr get(std::string_view name, std::nothrow_t) const noexcept; #endif - + /** Untyped getter for an object array class parameter, identified by a name * and an index. * Will attempt to get specified object of appropriate internal type. @@ -519,7 +509,7 @@ class SGObject: public std::enable_shared_from_this * @param _tag name and type information of parameter * @return value of the parameter identified by the input tag */ - template ::value && !is_sg_base::value>* = nullptr> + template ::value && !is_sg_base_v>* = nullptr> T get(const Tag& _tag) const noexcept(false) { const Any value = get_parameter(_tag).get_value(); @@ -539,7 +529,7 @@ class SGObject: public std::enable_shared_from_this return any_cast(value); } - template ::value>* = nullptr> + template >* = nullptr> std::shared_ptr get(const Tag& _tag) const noexcept(false) { const Any value = get_parameter(_tag).get_value(); @@ -590,7 +580,7 @@ class SGObject: public std::enable_shared_from_this * @param name name of the parameter * @return value of the parameter corresponding to the input name and type */ - template ::value>> + template >> #ifdef SWIG T get(const std::string& name) const noexcept(false) #else @@ -621,7 +611,7 @@ class SGObject: public std::enable_shared_from_this } #ifndef SWIG - template ::value>* = nullptr> + template >* = nullptr> std::shared_ptr get(std::string_view name) const noexcept(false) { Tag tag(name); @@ -783,6 +773,29 @@ class SGObject: public std::enable_shared_from_this } protected: + template + void register_interface_visitor() const + { + // only register a visitor with interface visitor if T + // is not a type that we are interested in, i.e. + // dont just register empty lambdas for numbers types etc + if constexpr(is_sg_base_v || traits::is_vector_v) + { + Any::register_visitor( + [] (auto value, auto visitor) { + if constexpr (is_sg_base_v) + { + visitor->value = value; + } + else if constexpr (traits::is_vector_v) + { + if constexpr (is_sg_base_v) + visitor->value = value.at(visitor->index); + } + }); + } + } + /** Registers a class parameter which is identified by a tag. * This enables the parameter to be modified by put() and retrieved by * get(). @@ -825,6 +838,7 @@ class SGObject: public std::enable_shared_from_this { BaseTag tag(name); create_parameter(tag, AnyParameter(make_any_ref(value), properties)); + register_interface_visitor(); } /** Puts a pointer to some parameter into the parameter map. @@ -850,6 +864,7 @@ class SGObject: public std::enable_shared_from_this tag, AnyParameter( make_any_ref(value), properties, std::move(auto_init))); + register_interface_visitor(); } #ifndef SWIG @@ -883,6 +898,7 @@ class SGObject: public std::enable_shared_from_this constrain_function.run(casted_val, result); return result; })); + register_interface_visitor(); } #endif @@ -902,6 +918,7 @@ class SGObject: public std::enable_shared_from_this BaseTag tag(name); create_parameter( tag, AnyParameter(make_any_ref(value, len), properties)); + register_interface_visitor(); } /** Puts a pointer to some 2d parameter array (i.e. a matrix) into the @@ -922,6 +939,7 @@ class SGObject: public std::enable_shared_from_this BaseTag tag(name); create_parameter( tag, AnyParameter(make_any_ref(value, rows, cols), properties)); + register_interface_visitor(); } #ifndef SWIG @@ -1234,82 +1252,5 @@ std::shared_ptr make_clone(std::shared_ptr orig, ParameterProp ASSERT(clone); return std::static_pointer_cast(clone); } - -#ifndef SWIG -#ifndef DOXYGEN_SHOULD_SKIP_THIS -namespace sgo_details -{ - template - bool dispatch_array_type( - const std::shared_ptr& obj, std::string_view name, - T2&& lambda) - { - Tag>> tag_vector(name); - if (obj->has(tag_vector)) - { - auto dispatched = obj->get(tag_vector); - lambda(dispatched); - return true; - } - return false; - } - - struct GetByName - { - }; - - struct GetByNameIndex - { - GetByNameIndex(index_t index) : m_index(index) {} - index_t m_index; - }; - - template - std::shared_ptr get_if_possible(const std::shared_ptr& obj, std::string_view name, GetByName) - { - return obj->has(name) ? obj->get(name) : nullptr; - } - - template - std::shared_ptr get_if_possible(const std::shared_ptr& obj, std::string_view name, GetByNameIndex how) - { - std::shared_ptr result = nullptr; - result = obj->get(name, how.m_index, std::nothrow); - return result; - } - - template - std::shared_ptr get_dispatch_all_base_types(const std::shared_ptr& obj, std::string_view name, - T&& how) - { - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - if (auto result = get_if_possible(obj, name, how)) - return result; - - return nullptr; - } - - template - std::shared_ptr get_by_tag(const std::shared_ptr& obj, std::string_view name, - T&& how) - { - return get_dispatch_all_base_types(obj, name, how); - } -} // namespace sgo_details - -#endif //DOXYGEN_SHOULD_SKIP_THIS -#endif //SWIG - } #endif // __SGOBJECT_H__ diff --git a/src/shogun/base/base_types.h b/src/shogun/base/base_types.h index 54e8262f719..b50a1fd5f02 100644 --- a/src/shogun/base/base_types.h +++ b/src/shogun/base/base_types.h @@ -35,33 +35,15 @@ namespace shogun class Tokenizer; class CombinationRule; - // type trait to enable certain methods only for shogun base types - // FIXME: use sg_interface to populate this trait - template - struct is_sg_base - : std::integral_constant< - bool, std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value> - { - }; + template + struct type_list{}; + + using sg_inferface = type_list; template struct is_string @@ -100,22 +82,12 @@ namespace shogun { }; - template - struct type_list{}; - template struct type_holder { typedef T type; }; - using sg_inferface = type_list; - template constexpr auto find_base(type_list<>) { @@ -124,30 +96,47 @@ namespace shogun template constexpr auto find_base(type_list) { - if constexpr (std::is_base_of_v) - return type_holder{}; - else - return find_base(type_list{}); + if constexpr (std::is_base_of_v) + return type_holder{}; + else + return find_base(type_list{}); } template using base_type = typename decltype(find_base(sg_inferface{}))::type; template - struct remove_shared_ptr + struct remove_shared_ptr { using type = T; }; template struct remove_shared_ptr> - { - using type = T; + { + using type = T; }; template using remove_shared_ptr_t = typename remove_shared_ptr::type; + template + constexpr auto is_sg_base(type_list<>) + { + return std::false_type{}; + } + + template + constexpr auto is_sg_base(type_list) { + if constexpr (std::is_same_v) + return std::true_type{}; + else + return is_sg_base(type_list{}); + } + + template + inline constexpr bool is_sg_base_v = is_sg_base>(sg_inferface{}); + } // namespace shogun #endif // BASE_TYPES__H diff --git a/src/shogun/util/traits.h b/src/shogun/util/traits.h index c4131bcbaac..a4e06fcc809 100644 --- a/src/shogun/util/traits.h +++ b/src/shogun/util/traits.h @@ -159,7 +159,22 @@ namespace shogun struct is_shared_ptr> : std::true_type { }; - + + template + struct is_vector : std::false_type + { + using type = T; + }; + + template + struct is_vector> : std::true_type + { + using type = std::vector ; + }; + + template + inline constexpr bool is_vector_v = is_vector::value; + #endif // DOXYGEN_SHOULD_SKIP_THIS } // namespace traits } // namespace shogun