|
20 | 20 |
|
21 | 21 | #include "cpu_tools.hpp" |
22 | 22 | #include "logger.hpp" |
| 23 | +#include "settings.h" |
23 | 24 |
|
24 | 25 | #ifdef USE_PYTHON_MODULE |
25 | 26 | #include "module_tools.hpp" |
@@ -61,7 +62,8 @@ std::ostream& operator<<(std::ostream& os, |
61 | 62 | } |
62 | 63 |
|
63 | 64 | namespace py_tools { |
64 | | -libs_availability_t libs_availability(libs_scan_type_t scan_type) { |
| 65 | +libs_availability_t libs_availability(libs_scan_type_t scan_type, |
| 66 | + unsigned int scan_flags) { |
65 | 67 | // run only in py thread |
66 | 68 |
|
67 | 69 | libs_availability_t availability{}; |
@@ -114,39 +116,50 @@ libs_availability_t libs_availability(libs_scan_type_t scan_type) { |
114 | 116 | LOGD("python version check py error: " << err.what()); |
115 | 117 | } |
116 | 118 |
|
117 | | - if (cpu_tools::cpuinfo().feature_flags & cpu_tools::feature_flags_t::avx) { |
118 | | - try { |
119 | | - LOGD("checking: torch cuda"); |
120 | | - auto torch_cuda = py::module_::import("torch.cuda"); |
121 | | - auto torch_ver = py::module_::import("torch.version"); |
122 | | - if (torch_cuda.attr("is_available")().cast<bool>()) { |
123 | | - try { |
124 | | - auto cuda_ver = torch_ver.attr("cuda").cast<std::string>(); |
125 | | - LOGD("torch cuda version: " << cuda_ver); |
126 | | - availability.torch_cuda = !cuda_ver.empty(); |
127 | | - } catch ([[maybe_unused]] const py::cast_error& err) { |
128 | | - } |
129 | | - try { |
130 | | - auto hip_ver = torch_ver.attr("hip").cast<std::string>(); |
131 | | - LOGD("torch hip version: " << hip_ver); |
132 | | - availability.torch_hip = !hip_ver.empty(); |
133 | | - } catch ([[maybe_unused]] const py::cast_error& err) { |
| 119 | + if ((scan_flags & settings::ScanFlagNoTorchCuda) > 0) { |
| 120 | + LOGD("checking: torch cuda (skipped)"); |
| 121 | + } else { |
| 122 | + if (cpu_tools::cpuinfo().feature_flags & |
| 123 | + cpu_tools::feature_flags_t::avx) { |
| 124 | + try { |
| 125 | + LOGD("checking: torch cuda"); |
| 126 | + auto torch_cuda = py::module_::import("torch.cuda"); |
| 127 | + auto torch_ver = py::module_::import("torch.version"); |
| 128 | + if (torch_cuda.attr("is_available")().cast<bool>()) { |
| 129 | + try { |
| 130 | + auto cuda_ver = |
| 131 | + torch_ver.attr("cuda").cast<std::string>(); |
| 132 | + LOGD("torch cuda version: " << cuda_ver); |
| 133 | + availability.torch_cuda = !cuda_ver.empty(); |
| 134 | + } catch ([[maybe_unused]] const py::cast_error& err) { |
| 135 | + } |
| 136 | + try { |
| 137 | + auto hip_ver = |
| 138 | + torch_ver.attr("hip").cast<std::string>(); |
| 139 | + LOGD("torch hip version: " << hip_ver); |
| 140 | + availability.torch_hip = !hip_ver.empty(); |
| 141 | + } catch ([[maybe_unused]] const py::cast_error& err) { |
| 142 | + } |
134 | 143 | } |
| 144 | + } catch (const std::exception& err) { |
| 145 | + LOGD("torch cuda check py error: " << err.what()); |
135 | 146 | } |
136 | | - } catch (const std::exception& err) { |
137 | | - LOGD("torch cuda check py error: " << err.what()); |
138 | 147 | } |
139 | 148 | } |
140 | 149 |
|
141 | | - try { |
142 | | - LOGD("checking: ctranslate2-cuda"); |
143 | | - auto ct2 = py::module_::import("ctranslate2"); |
144 | | - LOGD("ctranslate2 version: " |
145 | | - << ct2.attr("__version__").cast<std::string>()); |
146 | | - availability.ctranslate2_cuda = |
147 | | - py::len(ct2.attr("get_supported_compute_types")("cuda")) > 0; |
148 | | - } catch (const std::exception& err) { |
149 | | - LOGD("ctranslate2-cuda check py error: " << err.what()); |
| 150 | + if ((scan_flags & settings::ScanFlagNoCt2Cuda) > 0) { |
| 151 | + LOGD("checking: ctranslate2-cuda (skipped)"); |
| 152 | + } else { |
| 153 | + try { |
| 154 | + LOGD("checking: ctranslate2-cuda"); |
| 155 | + auto ct2 = py::module_::import("ctranslate2"); |
| 156 | + LOGD("ctranslate2 version: " |
| 157 | + << ct2.attr("__version__").cast<std::string>()); |
| 158 | + availability.ctranslate2_cuda = |
| 159 | + py::len(ct2.attr("get_supported_compute_types")("cuda")) > 0; |
| 160 | + } catch (const std::exception& err) { |
| 161 | + LOGD("ctranslate2-cuda check py error: " << err.what()); |
| 162 | + } |
150 | 163 | } |
151 | 164 |
|
152 | 165 | if (scan_type == libs_scan_type_t::off_all_enabled || |
|
0 commit comments