Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,12 +1106,13 @@ def test_all_invalid_plot_data(self):

@slow
def test_partially_invalid_plot_data(self):
kinds = 'line', 'bar', 'barh', 'kde', 'density'
df = DataFrame(randn(10, 2), dtype=object)
df[np.random.rand(df.shape[0]) > 0.5] = 'a'
for kind in kinds:
with tm.assertRaises(TypeError):
df.plot(kind=kind)
with tm.RNGContext(42):
kinds = 'line', 'bar', 'barh', 'kde', 'density'
df = DataFrame(randn(10, 2), dtype=object)
df[np.random.rand(df.shape[0]) > 0.5] = 'a'
for kind in kinds:
with tm.assertRaises(TypeError):
df.plot(kind=kind)

def test_invalid_kind(self):
df = DataFrame(randn(10, 2))
Expand Down
14 changes: 13 additions & 1 deletion pandas/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import sys
from pandas import Series
from pandas.util.testing import (
assert_almost_equal, assertRaisesRegexp, raise_with_traceback, assert_series_equal
assert_almost_equal, assertRaisesRegexp, raise_with_traceback, assert_series_equal,
RNGContext
)

# let's get meta.
Expand Down Expand Up @@ -153,3 +154,14 @@ def test_not_equal(self):
# ATM meta data is not checked in assert_series_equal
# self._assert_not_equal(Series(range(3)),Series(range(3),name='foo'),check_names=True)


class TestRNGContext(unittest.TestCase):

def test_RNGContext(self):
expected0 = 1.764052345967664
expected1 = 1.6243453636632417

with RNGContext(0):
with RNGContext(1):
self.assertEqual(np.random.randn(), expected1)
self.assertEqual(np.random.randn(), expected0)
30 changes: 30 additions & 0 deletions pandas/util/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,3 +1528,33 @@ def skip_if_no_ne(engine='numexpr'):
def disabled(t):
t.disabled = True
return t


class RNGContext(object):
"""
Context manager to set the numpy random number generator speed. Returns
to the original value upon exiting the context manager.

Parameters
----------
seed : int
Seed for numpy.random.seed

Examples
--------

with RNGContext(42):
np.random.randn()
"""

def __init__(self, seed):
self.seed = seed

def __enter__(self):

self.start_state = np.random.get_state()
np.random.seed(self.seed)

def __exit__(self, exc_type, exc_value, traceback):

np.random.set_state(self.start_state)