Skip to content

Commit cad54db

Browse files
authored
Merge pull request #65 from necolab-ugr/sklearn-check
Scikit-learn version management for specific example scripts
2 parents 4bc46bc + 563b0b5 commit cad54db

File tree

5 files changed

+36
-8
lines changed

5 files changed

+36
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ got 88 from PyObject*—you can resolve it by force-reinstalling both packages:
7070

7171
```bash
7272
pip uninstall scikit-learn numpy -y
73-
pip install scikit-learn==1.3.2 numpy
73+
pip install scikit-learn==1.5.0 numpy
7474
```
7575

7676
# Folder Structure

examples/EEG_AD/EEG_AD.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ def create_POCTEP_dataframe(data_path):
103103
return df
104104

105105
if __name__ == "__main__":
106+
# Check scikit-learn version
107+
if not tools.ensure_module('scikit-learn', 'scikit-learn==1.3.2'):
108+
print("Failed to install required scikit-learn version 1.3.2. Please install it manually.")
109+
106110
# Download simulation data and ML models
107111
if zenodo_dw_sim:
108112
print('\n--- Downloading simulation data and ML models from Zenodo.')

examples/LFP_developing_brain/LFP_developing_brain.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ def save_data(emp_data, method):
208208

209209

210210
if __name__ == "__main__":
211+
# Check scikit-learn version
212+
if not tools.ensure_module('scikit-learn', 'scikit-learn==1.3.2'):
213+
print("Failed to install required scikit-learn version 1.3.2. Please install it manually.")
214+
211215
# Download simulation data and ML models
212216
if zenodo_dw_sim:
213217
print('\n--- Downloading simulation data and ML models from Zenodo.')

ncpi/tools.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,31 @@ def ensure_module(module_name, package_name=None):
2828
# Check again after installation
2929
if importlib.util.find_spec(module_name) is None:
3030
raise ImportError(f"Installation failed: '{module_name}' not found after installation.")
31-
32-
return True # Module is available
33-
except subprocess.CalledProcessError:
34-
print(f"Error: Failed to install '{module_name}'. Please install it manually.")
31+
32+
# Check if installed version matches required version from package_name
33+
if package_name and '==' in package_name:
34+
required_version = package_name.split('==')[1]
35+
module = importlib.import_module(module_name)
36+
installed_version = getattr(module, '__version__', None)
37+
38+
if installed_version != required_version:
39+
print(f"Module '{module_name}' version {installed_version} found, but {required_version} required. Reinstalling...")
40+
# Uninstall current version
41+
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", module_name])
42+
# Install required version
43+
subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
44+
45+
# Clear cached modules
46+
for mod_name in list(sys.modules.keys()):
47+
if mod_name == module_name or mod_name.startswith(f"{module_name}."):
48+
del sys.modules[mod_name]
49+
50+
print(f"Module '{module_name}' reinstalled successfully with version {required_version}.")
51+
52+
return True
53+
54+
except subprocess.CalledProcessError as e:
55+
print(f"Error: {e} Failed to install '{module_name}'. Please install it manually.")
3556
except ImportError as e:
3657
print(f"Error: {e}")
3758
except Exception as e:
@@ -147,7 +168,6 @@ def download_zenodo_record(api_url, download_dir="zenodo_files", extract_tar=Tru
147168
print("If you want to re-download, please delete the directory.")
148169

149170

150-
151171
def timer(description=None):
152172
"""
153173
Decorator that measures and prints execution time of a function.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
requires-python = ">=3.10"
2424

2525
dependencies = [
26-
"scikit-learn==1.3.2",
26+
"scikit-learn==1.5.0",
2727
"sbi==0.24.0",
2828
"torch",
2929
"pycatch22",
@@ -77,4 +77,4 @@ documentation = "https://necolab-ugr.github.io/ncpi"
7777
[tool.setuptools.packages.find]
7878
where = ["."] # list of folders that contain the packages (["."] by default)
7979
include = ["*"] # list of packages to include (["*"] by default)
80-
exclude = ["tests", "tests.*"] # list of packages to exclude (["tests"] by default)
80+
exclude = ["tests", "tests.*"] # list of packages to exclude (["tests"] by default)

0 commit comments

Comments
 (0)