@@ -153,7 +153,7 @@ def __init__(
153153        """ 
154154        seed1 , seed2  =  random_seed .get_seed (seed )
155155        # If op level seed is not set, use whatever graph level seed is returned 
156-         np .random .seed (seed1  if  seed  is  None  else  seed2 )
156+         self . _rng   =   np .random .default_rng (seed1  if  seed  is  None  else  seed2 )
157157        dtype  =  dtypes .as_dtype (dtype ).base_dtype 
158158        if  dtype  not  in   (dtypes .uint8 , dtypes .float32 ):
159159            raise  TypeError ("Invalid image dtype %r, expected uint8 or float32"  %  dtype )
@@ -211,7 +211,7 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
211211        # Shuffle for the first epoch 
212212        if  self ._epochs_completed  ==  0  and  start  ==  0  and  shuffle :
213213            perm0  =  np .arange (self ._num_examples )
214-             np . random .shuffle (perm0 )
214+             self . _rng .shuffle (perm0 )
215215            self ._images  =  self .images [perm0 ]
216216            self ._labels  =  self .labels [perm0 ]
217217        # Go to the next epoch 
@@ -225,7 +225,7 @@ def next_batch(self, batch_size, fake_data=False, shuffle=True):
225225            # Shuffle the data 
226226            if  shuffle :
227227                perm  =  np .arange (self ._num_examples )
228-                 np . random .shuffle (perm )
228+                 self . _rng .shuffle (perm )
229229                self ._images  =  self .images [perm ]
230230                self ._labels  =  self .labels [perm ]
231231            # Start next epoch 
0 commit comments