Skip to content

Conversation

itonito
Copy link

@itonito itonito commented Aug 17, 2025

Reference Issues/PRs

Fixes #27342

What does this implement/fix? Explain your changes.

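Added a `pos_label` parameter to `TargetEncoder`, together with tests covering it in `sklearn/preprocessing/tests/test_target_encoder.py`.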

Any other comments?

Copy link

github-actions bot commented Aug 17, 2025

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here.


ruff format

ruff detected issues. Please run ruff format locally and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


--- sklearn/preprocessing/tests/test_target_encoder.py
+++ sklearn/preprocessing/tests/test_target_encoder.py
@@ -722,28 +722,20 @@
     y_binary = np.array([0, 1, 0, 1, 0, 1])
 
     # Default pos_label=1 should give same result as LabelBinarizer default
-    enc_default = TargetEncoder(
-        target_type="binary", cv=2, random_state=0
-    )
+    enc_default = TargetEncoder(target_type="binary", cv=2, random_state=0)
     enc_default.fit(X, y_binary)
 
     # Custom pos_label=0 should invert the encoding
-    enc_custom = TargetEncoder(
-        target_type="binary", pos_label=0, cv=2, random_state=0
-    )
+    enc_custom = TargetEncoder(target_type="binary", pos_label=0, cv=2, random_state=0)
     enc_custom.fit(X, y_binary)
 
     # The encodings should be different due to different pos_label
-    assert not np.allclose(
-        enc_default.encodings_[0], enc_custom.encodings_[0]
-    )
+    assert not np.allclose(enc_default.encodings_[0], enc_custom.encodings_[0])
 
     # Test multiclass classification with different pos_label values
     y_multiclass = np.array([0, 1, 2, 0, 1, 2])
 
-    enc_multi_default = TargetEncoder(
-        target_type="multiclass", cv=2, random_state=0
-    )
+    enc_multi_default = TargetEncoder(target_type="multiclass", cv=2, random_state=0)
     enc_multi_default.fit(X, y_multiclass)
 
     enc_multi_custom = TargetEncoder(
@@ -768,31 +760,22 @@
     y = np.array([0, 1, 0, 1, 0, 1])
 
     # Test with pos_label=1 (default)
-    enc_te = TargetEncoder(
-        target_type="binary", cv=2, random_state=0
-    )
+    enc_te = TargetEncoder(target_type="binary", cv=2, random_state=0)
     enc_te.fit(X, y)
 
     # Test with pos_label=0
-    enc_te_0 = TargetEncoder(
-        target_type="binary", pos_label=0, cv=2, random_state=0
-    )
+    enc_te_0 = TargetEncoder(target_type="binary", pos_label=0, cv=2, random_state=0)
     enc_te_0.fit(X, y)
 
     # The encodings should be different but the feature names should be consistent
-    assert not np.allclose(
-        enc_te.encodings_[0], enc_te_0.encodings_[0]
-    )
+    assert not np.allclose(enc_te.encodings_[0], enc_te_0.encodings_[0])
 
     # Both should have the same classes
     assert_array_equal(enc_te.classes_, enc_te_0.classes_)
 
     # Test transform output shapes are the same
     X_test = np.array([["a", "b", "c"]], dtype=object).T
-    assert (
-        enc_te.transform(X_test).shape
-        == enc_te_0.transform(X_test).shape
-    )
+    assert enc_te.transform(X_test).shape == enc_te_0.transform(X_test).shape
 
 
 def test_pos_label_parameter_edge_cases():
@@ -806,9 +789,7 @@
     y = np.array([0, 1, 0, 1])
 
     # Test with pos_label that doesn't exist in y (should still work)
-    enc = TargetEncoder(
-        target_type="binary", pos_label=99, cv=2, random_state=0
-    )
+    enc = TargetEncoder(target_type="binary", pos_label=99, cv=2, random_state=0)
     enc.fit(X, y)
 
     # Should still encode correctly
@@ -817,9 +798,7 @@
 
     # Test with pos_label that is a string (should work if y contains strings)
     y_str = np.array(["neg", "pos", "neg", "pos"])
-    enc_str = TargetEncoder(
-        target_type="binary", pos_label="pos", cv=2, random_state=0
-    )
+    enc_str = TargetEncoder(target_type="binary", pos_label="pos", cv=2, random_state=0)
     enc_str.fit(X, y_str)
 
     # Should encode correctly

1 file would be reformatted, 925 files already formatted

Generated for commit: ad4ee4e. Link to the linter CI: here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH Add pos_label parameter to TargetEncoder
1 participant