Skip to content

Commit 2f9375a

Browse files
committed
clean up ssl
1 parent d058f18 commit 2f9375a

File tree

2 files changed

+67
-41
lines changed

2 files changed

+67
-41
lines changed

Lib/test/test_ssl.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,6 @@ def getpass(self):
12801280
ctx.load_cert_chain(CERTFILE, password=getpass_exception)
12811281

12821282
@threading_helper.requires_working_threading()
1283-
@unittest.expectedFailure # TODO: RUSTPYTHON
12841283
def test_load_cert_chain_thread_safety(self):
12851284
# gh-134698: _ssl detaches the thread state (and as such,
12861285
# releases the GIL and critical sections) around expensive
@@ -1568,7 +1567,6 @@ def _assert_context_options(self, ctx):
15681567
self.assertEqual(ctx.options & ssl.OP_LEGACY_SERVER_CONNECT,
15691568
0 if IS_OPENSSL_3_0_0 else ssl.OP_LEGACY_SERVER_CONNECT)
15701569

1571-
@unittest.expectedFailure # TODO: RUSTPYTHON
15721570
def test_create_default_context(self):
15731571
ctx = ssl.create_default_context()
15741572

@@ -3135,7 +3133,6 @@ def test_ecc_cert(self):
31353133

31363134
@unittest.skipUnless(IS_OPENSSL_3_0_0,
31373135
"test requires RFC 5280 check added in OpenSSL 3.0+")
3138-
@unittest.expectedFailure # TODO: RUSTPYTHON
31393136
def test_verify_strict(self):
31403137
# verification fails by default, since the server cert is non-conforming
31413138
client_context = ssl.create_default_context()
@@ -3567,7 +3564,6 @@ def test_starttls(self):
35673564
else:
35683565
s.close()
35693566

3570-
@unittest.expectedFailure # TODO: RUSTPYTHON
35713567
def test_socketserver(self):
35723568
"""Using socketserver to create and manage SSL connections."""
35733569
server = make_https_server(self, certfile=SIGNED_CERTFILE)
@@ -5343,7 +5339,6 @@ def call_after_accept(conn_to_client):
53435339
wrap_error = None
53445340
server = None
53455341

5346-
@unittest.expectedFailure # TODO: RUSTPYTHON
53475342
def test_https_client_non_tls_response_ignored(self):
53485343
server_responding = threading.Event()
53495344

stdlib/src/ssl.rs

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ mod _ssl {
3939
socket::{self, PySocket},
4040
vm::{
4141
Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
42-
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyType, PyTypeRef, PyWeak},
42+
builtins::{PyBaseExceptionRef, PyBytesRef, PyListRef, PyStrRef, PyTypeRef, PyWeak},
4343
class_or_notimplemented,
4444
convert::{ToPyException, ToPyObject},
4545
exceptions,
@@ -66,7 +66,8 @@ mod _ssl {
6666
ffi::CStr,
6767
fmt,
6868
io::{Read, Write},
69-
path::Path,
69+
path::{Path, PathBuf},
70+
sync::LazyLock,
7071
time::Instant,
7172
};
7273

@@ -191,7 +192,8 @@ mod _ssl {
191192

192193
#[pyattr(name = "_OPENSSL_API_VERSION")]
193194
fn _openssl_api_version(_vm: &VirtualMachine) -> OpensslVersionInfo {
194-
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16).unwrap();
195+
let openssl_api_version = i64::from_str_radix(env!("OPENSSL_API_VERSION"), 16)
196+
.expect("OPENSSL_API_VERSION is malformed");
195197
parse_version_info(openssl_api_version)
196198
}
197199

@@ -249,7 +251,8 @@ mod _ssl {
249251
/// SSL/TLS connection terminated abruptly.
250252
#[pyattr(name = "SSLEOFError", once)]
251253
fn ssl_eof_error(vm: &VirtualMachine) -> PyTypeRef {
252-
PyType::new_simple_heap("ssl.SSLEOFError", &ssl_error(vm), &vm.ctx).unwrap()
254+
vm.ctx
255+
.new_exception_type("ssl", "SSLEOFError", Some(vec![ssl_error(vm)]))
253256
}
254257

255258
type OpensslVersionInfo = (u8, u8, u8, u8, u8);
@@ -350,14 +353,17 @@ mod _ssl {
350353
}
351354

352355
type PyNid = (libc::c_int, String, String, Option<String>);
353-
fn obj2py(obj: &Asn1ObjectRef) -> PyNid {
356+
fn obj2py(obj: &Asn1ObjectRef, vm: &VirtualMachine) -> PyResult<PyNid> {
354357
let nid = obj.nid();
355-
(
356-
nid.as_raw(),
357-
nid.short_name().unwrap().to_owned(),
358-
nid.long_name().unwrap().to_owned(),
359-
obj2txt(obj, true),
360-
)
358+
let short_name = nid
359+
.short_name()
360+
.map_err(|_| vm.new_value_error("NID has no short name".to_owned()))?
361+
.to_owned();
362+
let long_name = nid
363+
.long_name()
364+
.map_err(|_| vm.new_value_error("NID has no long name".to_owned()))?
365+
.to_owned();
366+
Ok((nid.as_raw(), short_name, long_name, obj2txt(obj, true)))
361367
}
362368

363369
#[derive(FromArgs)]
@@ -371,55 +377,81 @@ mod _ssl {
371377
fn txt2obj(args: Txt2ObjArgs, vm: &VirtualMachine) -> PyResult<PyNid> {
372378
_txt2obj(&args.txt.to_cstring(vm)?, !args.name)
373379
.as_deref()
374-
.map(obj2py)
375380
.ok_or_else(|| vm.new_value_error(format!("unknown object '{}'", args.txt)))
381+
.and_then(|obj| obj2py(obj, vm))
376382
}
377383

378384
#[pyfunction]
379385
fn nid2obj(nid: libc::c_int, vm: &VirtualMachine) -> PyResult<PyNid> {
380386
_nid2obj(Nid::from_raw(nid))
381387
.as_deref()
382-
.map(obj2py)
383388
.ok_or_else(|| vm.new_value_error(format!("unknown NID {nid}")))
389+
.and_then(|obj| obj2py(obj, vm))
384390
}
385391

386-
fn get_cert_file_dir() -> (&'static Path, &'static Path) {
387-
let probe = probe();
388-
// on windows, these should be utf8 strings
389-
fn path_from_bytes(c: &CStr) -> &Path {
392+
// Lazily compute and cache cert file/dir paths
393+
static CERT_PATHS: LazyLock<(PathBuf, PathBuf)> = LazyLock::new(|| {
394+
fn path_from_cstr(c: &CStr) -> PathBuf {
390395
#[cfg(unix)]
391396
{
392397
use std::os::unix::ffi::OsStrExt;
393-
std::ffi::OsStr::from_bytes(c.to_bytes()).as_ref()
398+
std::ffi::OsStr::from_bytes(c.to_bytes()).into()
394399
}
395400
#[cfg(windows)]
396401
{
397-
c.to_str().unwrap().as_ref()
402+
// Use lossy conversion for potential non-UTF8
403+
PathBuf::from(c.to_string_lossy().as_ref())
398404
}
399405
}
400-
let cert_file = probe.cert_file.as_deref().unwrap_or_else(|| {
401-
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
402-
});
403-
let cert_dir = probe.cert_dir.as_deref().unwrap_or_else(|| {
404-
path_from_bytes(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
405-
});
406+
407+
let probe = probe();
408+
let cert_file = probe
409+
.cert_file
410+
.as_ref()
411+
.map(PathBuf::from)
412+
.unwrap_or_else(|| {
413+
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_file()) })
414+
});
415+
let cert_dir = probe
416+
.cert_dir
417+
.as_ref()
418+
.map(PathBuf::from)
419+
.unwrap_or_else(|| {
420+
path_from_cstr(unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir()) })
421+
});
406422
(cert_file, cert_dir)
423+
});
424+
425+
fn get_cert_file_dir() -> (&'static Path, &'static Path) {
426+
let (cert_file, cert_dir) = &*CERT_PATHS;
427+
(cert_file.as_path(), cert_dir.as_path())
407428
}
408429

430+
// Lazily compute and cache cert environment variable names
431+
static CERT_ENV_NAMES: LazyLock<(String, String)> = LazyLock::new(|| {
432+
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
433+
.to_string_lossy()
434+
.into_owned();
435+
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
436+
.to_string_lossy()
437+
.into_owned();
438+
(cert_file_env, cert_dir_env)
439+
});
440+
409441
#[pyfunction]
410442
fn get_default_verify_paths(
411443
vm: &VirtualMachine,
412444
) -> PyResult<(&'static str, PyObjectRef, &'static str, PyObjectRef)> {
413-
let cert_file_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_file_env()) }
414-
.to_str()
415-
.unwrap();
416-
let cert_dir_env = unsafe { CStr::from_ptr(sys::X509_get_default_cert_dir_env()) }
417-
.to_str()
418-
.unwrap();
445+
let (cert_file_env, cert_dir_env) = &*CERT_ENV_NAMES;
419446
let (cert_file, cert_dir) = get_cert_file_dir();
420447
let cert_file = OsPath::new_str(cert_file).filename(vm);
421448
let cert_dir = OsPath::new_str(cert_dir).filename(vm);
422-
Ok((cert_file_env, cert_file, cert_dir_env, cert_dir))
449+
Ok((
450+
cert_file_env.as_str(),
451+
cert_file,
452+
cert_dir_env.as_str(),
453+
cert_dir,
454+
))
423455
}
424456

425457
#[pyfunction(name = "RAND_status")]
@@ -1869,12 +1901,12 @@ mod _ssl {
18691901
}
18701902

18711903
#[pygetset]
1872-
fn id(&self, vm: &VirtualMachine) -> PyObjectRef {
1904+
fn id(&self, vm: &VirtualMachine) -> PyBytesRef {
18731905
unsafe {
18741906
let mut len: libc::c_uint = 0;
18751907
let id_ptr = sys::SSL_SESSION_get_id(self.session, &mut len);
18761908
let id_slice = std::slice::from_raw_parts(id_ptr, len as usize);
1877-
vm.ctx.new_bytes(id_slice.to_vec()).into()
1909+
vm.ctx.new_bytes(id_slice.to_vec())
18781910
}
18791911
}
18801912

@@ -2259,8 +2291,7 @@ mod windows {
22592291
ValidUses::Oids(oids) => PyFrozenSet::from_iter(
22602292
vm,
22612293
oids.into_iter().map(|oid| vm.ctx.new_str(oid).into()),
2262-
)
2263-
.unwrap()
2294+
)?
22642295
.into_ref(&vm.ctx)
22652296
.into(),
22662297
};

0 commit comments

Comments
 (0)