Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 48 additions & 17 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,17 +350,25 @@ std::string SGObject::get_description(std::string_view name) const

void SGObject::build_gradient_parameter_dictionary(std::map<Parameters::value_type, std::shared_ptr<SGObject>>& dict)
{
for (auto& param: self->filter(ParameterProperties::GRADIENT))
dict[{param.first.name(), std::make_shared<const AnyParameter>(param.second)}] = shared_from_this();
for (auto& param: self->filter(ParameterProperties::GRADIENT))
dict[{param.first.name(), std::make_shared<const AnyParameter>(param.second)}] = shared_from_this();

for (const auto& param: self->filter(ParameterProperties::HYPER))
InterfaceVisitor iv;
for (const auto& param: self->filter(ParameterProperties::HYPER))
{
if (auto child = sgo_details::get_by_tag(shared_from_this(), param.first.name(), sgo_details::GetByName()))
child->build_gradient_parameter_dictionary(dict);
else if (auto child = get(param.first.name(), std::nothrow))
child->build_gradient_parameter_dictionary(dict);
else
SG_DEBUG("Parameter {} is not a SGObject. Skipping...", param.first.name().c_str())
try
{
param.second.get_value().visit_with(&iv);
if (iv.value)
{
iv.value->build_gradient_parameter_dictionary(dict);
iv.value.reset(); // need to reset it otherwise iv should be created after each iter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

}
}
catch (std::logic_error& e)
{
SG_DEBUG("Parameter {} is not a SGObject. Skipping...", param.first.name().c_str());
}
}
}

Expand Down Expand Up @@ -714,21 +722,44 @@ void SGObject::init_auto_params()

std::shared_ptr<SGObject> SGObject::get(std::string_view name, index_t index) const
{
auto result = sgo_details::get_by_tag(shared_from_this(), name, sgo_details::GetByNameIndex(index));
if (!result && has(name))
BaseTag tag(name);
if (self->has(tag))
{
error(
"Cannot get array parameter {}::{}[{}] of type {} as object.",
get_name(), name.data(), index,
self->map[BaseTag(name)].get_value().type().c_str());
try
{
InterfaceVisitor iv;
iv.index = index;
self->get(tag).get_value().visit_with(&iv);
return std::move(iv.value);
}
catch (...) { /* handle it below... */ }
}
return result;

error(
"Cannot get array parameter {}::{}[{}] of type {} as object.",
get_name(), name.data(), index,
self->map[BaseTag(name)].get_value().type().c_str());

return nullptr;
}

std::shared_ptr<SGObject> SGObject::get(std::string_view name, std::nothrow_t) const
noexcept
{
return sgo_details::get_by_tag(shared_from_this(), name, sgo_details::GetByName());
BaseTag tag(name);
if (!self->has(tag))
return nullptr;

try
{
InterfaceVisitor iv;
self->get(tag).get_value().visit_with(&iv);
return std::move(iv.value);
}
catch(...)
{
return nullptr;
}
}

std::shared_ptr<SGObject> SGObject::get(std::string_view name) const noexcept(false)
Expand Down
161 changes: 51 additions & 110 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,6 @@ class ObservedValueTemplated;
class Serializer;
}

#ifndef SWIG
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace sgo_details
{
template <typename T1, typename T2>
bool dispatch_array_type(
const std::shared_ptr<const SGObject>& obj, std::string_view name,
T2&& lambda);
} // namespace sgo_details
#endif // DOXYGEN_SHOULD_SKIP_THIS
#endif // SWIG

using stringToEnumMapType = std::unordered_map<std::string_view, std::unordered_map<std::string_view, machine_int_t>>;

/*******************************************************************************
Expand Down Expand Up @@ -130,6 +118,12 @@ SG_FORCED_INLINE const char* convert_string_to_char(const char* name)
*/
class SGObject: public std::enable_shared_from_this<SGObject>
{
struct InterfaceVisitor
{
index_t index;
std::shared_ptr<SGObject> value;
};

public:
/** Definition of observed subject */
typedef rxcpp::subjects::subject<std::shared_ptr<ObservedValue>> SGSubject;
Expand Down Expand Up @@ -270,7 +264,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @param name name of the parameter
* @return true if the parameter exists with the input name and type
*/
template <typename T, typename std::enable_if_t<!is_sg_base<T>::value>* = nullptr>
template <typename T, typename std::enable_if_t<!is_sg_base_v<T>>* = nullptr>
bool has(std::string_view name) const noexcept(true)
{
BaseTag tag(name);
Expand All @@ -280,7 +274,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
return value.has_type<T>();
}

template <typename T, typename std::enable_if_t<is_sg_base<T>::value>* = nullptr>
template <typename T, typename std::enable_if_t<is_sg_base_v<T>>* = nullptr>
bool has(std::string_view name) const noexcept(true)
{
BaseTag tag(name);
Expand Down Expand Up @@ -373,7 +367,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @param value value of the parameter
*/
template <class T,
class X = typename std::enable_if_t<is_sg_base<T>::value>,
class X = typename std::enable_if_t<is_sg_base_v<T>>,
class Z = void>
#ifdef SWIG
void put(const std::string& name, std::shared_ptr<T> value)
Expand All @@ -391,7 +385,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @param value value of the parameter
*/
template <class T,
class X = typename std::enable_if_t<is_sg_base<T>::value>>
class X = typename std::enable_if_t<is_sg_base_v<T>>>
#ifdef SWIG
void add(const std::string& name, std::shared_ptr<T> value)
#else
Expand Down Expand Up @@ -419,26 +413,22 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @return desired element
*/
template <class T,
class X = typename std::enable_if_t<is_sg_base<T>::value>>
class X = typename std::enable_if_t<is_sg_base_v<T>>>
std::shared_ptr<T> get(std::string_view name, index_t index, std::nothrow_t) const
{
std::shared_ptr<SGObject> result;

auto get_lambda = [&index, &result](auto& array) {
result = array.at(index);
};
if (sgo_details::dispatch_array_type<T>(shared_from_this(), name, get_lambda))
Tag<std::vector<std::shared_ptr<T>>> tag(name);
if (has<T>(tag))
{
ASSERT(result);
// guard against mixed types in the array
return result->as<T>();
InterfaceVisitor iv;
iv.index = index;
get<T>(tag).get_value().visit_with(&iv);
return std::move(iv.value->as<T>());
}

return nullptr;
}

template <class T,
class X = typename std::enable_if_t<is_sg_base<T>::value>>
class X = typename std::enable_if_t<is_sg_base_v<T>>>
std::shared_ptr<T> get(std::string_view name, index_t index) const
{
auto result = this->get<T>(name, index, std::nothrow);
Expand Down Expand Up @@ -475,7 +465,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
*/
std::shared_ptr<SGObject> get(std::string_view name, std::nothrow_t) const noexcept;
#endif

/** Untyped getter for an object array class parameter, identified by a name
* and an index.
* Will attempt to get specified object of appropriate internal type.
Expand Down Expand Up @@ -519,7 +509,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @param _tag name and type information of parameter
* @return value of the parameter identified by the input tag
*/
template <typename T, typename std::enable_if_t<!is_string<T>::value && !is_sg_base<T>::value>* = nullptr>
template <typename T, typename std::enable_if_t<!is_string<T>::value && !is_sg_base_v<T>>* = nullptr>
T get(const Tag<T>& _tag) const noexcept(false)
{
const Any value = get_parameter(_tag).get_value();
Expand All @@ -539,7 +529,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
return any_cast<T>(value);
}

template <typename T, typename std::enable_if_t<is_sg_base<T>::value>* = nullptr>
template <typename T, typename std::enable_if_t<is_sg_base_v<T>>* = nullptr>
std::shared_ptr<T> get(const Tag<T>& _tag) const noexcept(false)
{
const Any value = get_parameter(_tag).get_value();
Expand Down Expand Up @@ -590,7 +580,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* @param name name of the parameter
* @return value of the parameter corresponding to the input name and type
*/
template <typename T, class X = typename std::enable_if_t<!is_sg_base<T>::value>>
template <typename T, class X = typename std::enable_if_t<!is_sg_base_v<T>>>
#ifdef SWIG
T get(const std::string& name) const noexcept(false)
#else
Expand Down Expand Up @@ -621,7 +611,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
}

#ifndef SWIG
template <typename T, typename std::enable_if_t<is_sg_base<T>::value>* = nullptr>
template <typename T, typename std::enable_if_t<is_sg_base_v<T>>* = nullptr>
std::shared_ptr<T> get(std::string_view name) const noexcept(false)
{
Tag<T> tag(name);
Expand Down Expand Up @@ -783,6 +773,29 @@ class SGObject: public std::enable_shared_from_this<SGObject>
}

protected:
template<typename T>
void register_interface_visitor() const
{
// only register a visitor with interface visitor if T
// is not a type that we are interested in, i.e.
// dont just register empty lambdas for numbers types etc
if constexpr(is_sg_base_v<T> || traits::is_vector_v<T>)
{
Any::register_visitor<T, InterfaceVisitor>(
[] (auto value, auto visitor) {
if constexpr (is_sg_base_v<T>)
{
visitor->value = value;
}
else if constexpr (traits::is_vector_v<T>)
{
if constexpr (is_sg_base_v<typename T::value_type>)
visitor->value = value.at(visitor->index);
}
});
}
}

/** Registers a class parameter which is identified by a tag.
* This enables the parameter to be modified by put() and retrieved by
* get().
Expand Down Expand Up @@ -825,6 +838,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
{
BaseTag tag(name);
create_parameter(tag, AnyParameter(make_any_ref(value), properties));
register_interface_visitor<T>();
}

/** Puts a pointer to some parameter into the parameter map.
Expand All @@ -850,6 +864,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
tag,
AnyParameter(
make_any_ref(value), properties, std::move(auto_init)));
register_interface_visitor<T>();
}

#ifndef SWIG
Expand Down Expand Up @@ -883,6 +898,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
constrain_function.run(casted_val, result);
return result;
}));
register_interface_visitor<T1>();
}
#endif

Expand All @@ -902,6 +918,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
BaseTag tag(name);
create_parameter(
tag, AnyParameter(make_any_ref(value, len), properties));
register_interface_visitor<T>();
}

/** Puts a pointer to some 2d parameter array (i.e. a matrix) into the
Expand All @@ -922,6 +939,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
BaseTag tag(name);
create_parameter(
tag, AnyParameter(make_any_ref(value, rows, cols), properties));
register_interface_visitor<T>();
}

#ifndef SWIG
Expand Down Expand Up @@ -1234,82 +1252,5 @@ std::shared_ptr<const T> make_clone(std::shared_ptr<const T> orig, ParameterProp
ASSERT(clone);
return std::static_pointer_cast<const T>(clone);
}

#ifndef SWIG
#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace sgo_details
{
template <typename T1, typename T2>
bool dispatch_array_type(
const std::shared_ptr<const SGObject>& obj, std::string_view name,
T2&& lambda)
{
Tag<std::vector<std::shared_ptr<T1>>> tag_vector(name);
if (obj->has(tag_vector))
{
auto dispatched = obj->get(tag_vector);
lambda(dispatched);
return true;
}
return false;
}

struct GetByName
{
};

struct GetByNameIndex
{
GetByNameIndex(index_t index) : m_index(index) {}
index_t m_index;
};

template <typename T>
std::shared_ptr<SGObject> get_if_possible(const std::shared_ptr<const SGObject>& obj, std::string_view name, GetByName)
{
return obj->has<T>(name) ? obj->get<T>(name) : nullptr;
}

template <typename T>
std::shared_ptr<SGObject> get_if_possible(const std::shared_ptr<const SGObject>& obj, std::string_view name, GetByNameIndex how)
{
std::shared_ptr<SGObject> result = nullptr;
result = obj->get<T>(name, how.m_index, std::nothrow);
return result;
}

template<typename T>
std::shared_ptr<SGObject> get_dispatch_all_base_types(const std::shared_ptr<const SGObject>& obj, std::string_view name,
T&& how)
{
if (auto result = get_if_possible<Kernel>(obj, name, how))
return result;
if (auto result = get_if_possible<Features>(obj, name, how))
return result;
if (auto result = get_if_possible<Machine>(obj, name, how))
return result;
if (auto result = get_if_possible<Labels>(obj, name, how))
return result;
if (auto result = get_if_possible<EvaluationResult>(obj, name, how))
return result;
if (auto result = get_if_possible<LikelihoodModel>(obj, name, how))
return result;
if (auto result = get_if_possible<MeanFunction>(obj, name, how))
return result;

return nullptr;
}

template<class T>
std::shared_ptr<SGObject> get_by_tag(const std::shared_ptr<const SGObject>& obj, std::string_view name,
T&& how)
{
return get_dispatch_all_base_types(obj, name, how);
}
} // namespace sgo_details

#endif //DOXYGEN_SHOULD_SKIP_THIS
#endif //SWIG

}
#endif // __SGOBJECT_H__
Loading