Skip to content

Commit 921fc7c

Browse files
committed
add options to disable torch and ct2 detection
1 parent e81cda2 commit 921fc7c

File tree

6 files changed

+95
-31
lines changed

6 files changed

+95
-31
lines changed

desktop/qml/SettingsAdvancedPage.qml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,29 @@ ColumnLayout {
306306
}
307307
}
308308

309+
function setScanFlags() {
310+
var flags = Settings.ScanFlagNone
311+
if (scanFlagTorchCheckbox.checked) flags |= Settings.ScanFlagNoTorchCuda
312+
if (scanFlagCt2Checkbox.checked) flags |= Settings.ScanFlagNoCt2Cuda
313+
_settings.scan_flags = flags
314+
}
315+
316+
CheckBox {
317+
id: scanFlagTorchCheckbox
318+
319+
checked: _settings.scan_flags & Settings.ScanFlagNoTorchCuda
320+
text: qsTranslate("SettingsPage", "Disable detection of PyTorch")
321+
onCheckedChanged: setScanFlags()
322+
}
323+
324+
CheckBox {
325+
id: scanFlagCt2Checkbox
326+
327+
checked: _settings.scan_flags & Settings.ScanFlagNoCt2Cuda
328+
text: qsTranslate("SettingsPage", "Disable detection of CTranslate2")
329+
onCheckedChanged: setScanFlags()
330+
}
331+
309332
SectionLabel {
310333
visible: app.feature_fake_keyboard || ydoMessage.visible
311334
text: qsTranslate("SettingsPage", "Insert into active window")

src/py_executor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ void py_executor::loop() {
121121
return py_tools::libs_scan_type_t::off_all_disabled;
122122
}
123123
LOGF("invalid py scan mode");
124-
}());
124+
}(),
125+
settings::instance()->scan_flags());
125126
#else
126127
libs_availability = py_tools::libs_availability_t{};
127128
#endif

src/py_tools.cpp

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "cpu_tools.hpp"
2222
#include "logger.hpp"
23+
#include "settings.h"
2324

2425
#ifdef USE_PYTHON_MODULE
2526
#include "module_tools.hpp"
@@ -61,7 +62,8 @@ std::ostream& operator<<(std::ostream& os,
6162
}
6263

6364
namespace py_tools {
64-
libs_availability_t libs_availability(libs_scan_type_t scan_type) {
65+
libs_availability_t libs_availability(libs_scan_type_t scan_type,
66+
unsigned int scan_flags) {
6567
// run only in py thread
6668

6769
libs_availability_t availability{};
@@ -114,39 +116,50 @@ libs_availability_t libs_availability(libs_scan_type_t scan_type) {
114116
LOGD("python version check py error: " << err.what());
115117
}
116118

117-
if (cpu_tools::cpuinfo().feature_flags & cpu_tools::feature_flags_t::avx) {
118-
try {
119-
LOGD("checking: torch cuda");
120-
auto torch_cuda = py::module_::import("torch.cuda");
121-
auto torch_ver = py::module_::import("torch.version");
122-
if (torch_cuda.attr("is_available")().cast<bool>()) {
123-
try {
124-
auto cuda_ver = torch_ver.attr("cuda").cast<std::string>();
125-
LOGD("torch cuda version: " << cuda_ver);
126-
availability.torch_cuda = !cuda_ver.empty();
127-
} catch ([[maybe_unused]] const py::cast_error& err) {
128-
}
129-
try {
130-
auto hip_ver = torch_ver.attr("hip").cast<std::string>();
131-
LOGD("torch hip version: " << hip_ver);
132-
availability.torch_hip = !hip_ver.empty();
133-
} catch ([[maybe_unused]] const py::cast_error& err) {
119+
if ((scan_flags & settings::ScanFlagNoTorchCuda) > 0) {
120+
LOGD("checking: torch cuda (skipped)");
121+
} else {
122+
if (cpu_tools::cpuinfo().feature_flags &
123+
cpu_tools::feature_flags_t::avx) {
124+
try {
125+
LOGD("checking: torch cuda");
126+
auto torch_cuda = py::module_::import("torch.cuda");
127+
auto torch_ver = py::module_::import("torch.version");
128+
if (torch_cuda.attr("is_available")().cast<bool>()) {
129+
try {
130+
auto cuda_ver =
131+
torch_ver.attr("cuda").cast<std::string>();
132+
LOGD("torch cuda version: " << cuda_ver);
133+
availability.torch_cuda = !cuda_ver.empty();
134+
} catch ([[maybe_unused]] const py::cast_error& err) {
135+
}
136+
try {
137+
auto hip_ver =
138+
torch_ver.attr("hip").cast<std::string>();
139+
LOGD("torch hip version: " << hip_ver);
140+
availability.torch_hip = !hip_ver.empty();
141+
} catch ([[maybe_unused]] const py::cast_error& err) {
142+
}
134143
}
144+
} catch (const std::exception& err) {
145+
LOGD("torch cuda check py error: " << err.what());
135146
}
136-
} catch (const std::exception& err) {
137-
LOGD("torch cuda check py error: " << err.what());
138147
}
139148
}
140149

141-
try {
142-
LOGD("checking: ctranslate2-cuda");
143-
auto ct2 = py::module_::import("ctranslate2");
144-
LOGD("ctranslate2 version: "
145-
<< ct2.attr("__version__").cast<std::string>());
146-
availability.ctranslate2_cuda =
147-
py::len(ct2.attr("get_supported_compute_types")("cuda")) > 0;
148-
} catch (const std::exception& err) {
149-
LOGD("ctranslate2-cuda check py error: " << err.what());
150+
if ((scan_flags & settings::ScanFlagNoCt2Cuda) > 0) {
151+
LOGD("checking: ctranslate2-cuda (skipped)");
152+
} else {
153+
try {
154+
LOGD("checking: ctranslate2-cuda");
155+
auto ct2 = py::module_::import("ctranslate2");
156+
LOGD("ctranslate2 version: "
157+
<< ct2.attr("__version__").cast<std::string>());
158+
availability.ctranslate2_cuda =
159+
py::len(ct2.attr("get_supported_compute_types")("cuda")) > 0;
160+
} catch (const std::exception& err) {
161+
LOGD("ctranslate2-cuda check py error: " << err.what());
162+
}
150163
}
151164

152165
if (scan_type == libs_scan_type_t::off_all_enabled ||

src/py_tools.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ struct libs_availability_t {
4848
bool uroman = false;
4949
};
5050

51-
libs_availability_t libs_availability(libs_scan_type_t scan_type);
51+
libs_availability_t libs_availability(libs_scan_type_t scan_type,
52+
unsigned int scan_flags);
5253
bool init_module();
5354
} // namespace py_tools
5455

src/settings.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2901,3 +2901,17 @@ void settings::set_fake_keyboard_delay(int value) {
29012901
emit fake_keyboard_delay_changed();
29022902
}
29032903
}
2904+
2905+
unsigned int settings::scan_flags() const {
2906+
return value(QStringLiteral("service/scan_flags"),
2907+
scan_flags_t::ScanFlagNone)
2908+
.toUInt();
2909+
}
2910+
2911+
void settings::set_scan_flags(unsigned int flags) {
2912+
if (scan_flags() != flags) {
2913+
setValue(QStringLiteral("service/scan_flags"), flags);
2914+
emit scan_flags_changed();
2915+
set_restart_required(true);
2916+
}
2917+
}

src/settings.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,8 @@ class settings : public QSettings, public singleton<settings> {
368368
Q_PROPERTY(
369369
QString gpu_overrided_version READ gpu_overrided_version WRITE
370370
set_gpu_overrided_version NOTIFY gpu_overrided_version_changed)
371+
Q_PROPERTY(unsigned int scan_flags READ scan_flags WRITE set_scan_flags
372+
NOTIFY scan_flags_changed)
371373

372374
// engine options
373375
#define X(name) \
@@ -654,6 +656,13 @@ class settings : public QSettings, public singleton<settings> {
654656
};
655657
Q_ENUM(py_scan_mode_t)
656658

659+
enum scan_flags_t : unsigned int {
660+
ScanFlagNone = 0U,
661+
ScanFlagNoTorchCuda = 1U << 0U,
662+
ScanFlagNoCt2Cuda = 1U << 1U
663+
};
664+
Q_ENUM(scan_flags_t)
665+
657666
struct voice_profile_prompt_t {
658667
QString name;
659668
QString desc;
@@ -947,6 +956,8 @@ class settings : public QSettings, public singleton<settings> {
947956
void set_gpu_override_version(bool value);
948957
QString gpu_overrided_version();
949958
void set_gpu_overrided_version(QString new_value);
959+
unsigned int scan_flags() const;
960+
void set_scan_flags(unsigned int flags);
950961

951962
#define X(name) \
952963
bool name##_autolang_with_sup() const; \
@@ -1073,6 +1084,7 @@ class settings : public QSettings, public singleton<settings> {
10731084
void gpu_override_version_changed();
10741085
void gpu_overrided_version_changed();
10751086
void gpu_devices_changed();
1087+
void scan_flags_changed();
10761088

10771089
#define X(name) void name##_changed();
10781090
X(whispercpp)

0 commit comments

Comments
 (0)