Skip to content

Commit 7bf3754

Browse files
committed
normalize hard sigmoid activation name with brulee
1 parent 44a08b7 commit 7bf3754

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

R/mlp.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ keras_mlp <-
200200
{.val {activation}}."
201201
)
202202
}
203+
activation <- get_activation_fn(activation)
203204

204205
if (penalty > 0 & dropout > 0) {
205206
cli::cli_abort("Please use either dropout or weight decay.", call = NULL)
@@ -351,7 +352,7 @@ mlp_num_weights <- function(p, hidden_units, classes) {
351352
}
352353

353354
allowed_keras_activation <-
354-
c("elu", "exponential", "gelu", "hard_sigmoid", "linear", "relu", "selu",
355+
c("elu", "exponential", "gelu", "hardsigmoid", "linear", "relu", "selu",
355356
"sigmoid", "softmax", "softplus", "softsign", "swish", "tanh")
356357

357358
#' Activation functions for neural networks in keras
@@ -363,6 +364,13 @@ keras_activations <- function() {
363364
allowed_keras_activation
364365
}
365366

367+
get_activation_fn <- function(arg, ...) {
368+
if (arg == "hardsigmoid") {
369+
arg <- "hard_sigmoid"
370+
}
371+
arg
372+
}
373+
366374
## -----------------------------------------------------------------------------
367375

368376
#' @importFrom purrr map

tests/testthat/_snaps/mlp_keras.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,13 @@
66
Error:
77
! object 'novar' not found
88

9+
# all keras activation functions
10+
11+
Code
12+
mlp(mode = "classification", hidden_units = 2, penalty = 0.01, epochs = 2,
13+
activation = "invalid") %>% set_engine("keras", verbose = 0) %>% parsnip::fit(
14+
Class ~ A + B, data = modeldata::two_class_dat)
15+
Condition
16+
Error in `parsnip::keras_mlp()`:
17+
! `activation` should be one of: elu, exponential, gelu, hardsigmoid, linear, relu, selu, sigmoid, softmax, softplus, softsign, swish, and tanh, not "invalid".
18+

0 commit comments

Comments
 (0)