Subclassing `ndarray` following tutorial yields unexpected results (i.e. partial memory, some attributes are remembers others are lost)
I think I followed the subclassing tutorial correctly. I have a very simple example. It works when I run the code once. When I rerun a cell in a Jupyter notebook, then the class breaks and it "forgets" state (well it remembers the stuff I added, it forgets that transpose I did to numpy array). See code below.
Below I implement three simple classes NamedAxis, NamedAxes, and NamedArray (yes I am aware of xarray, this is for my own learning purposes). Mostly it works fine. However, I notice something very frustrating when I rerun flip
from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Union, Optional, Any, Callable, TypeVar, Generic, Type, cast, Tuple
import numpy as np, pandas as pd
@dataclass
class NamedAxis:
# name of axis
name: str
# index of axis
axis: Optional[int] = None
def __str__(self):
return f'{self.name}({self.axis})'
__repr__ = __str__
def copy(self) -> 'NamedAxis':
copy = deepcopy(self)
return copy
@dataclass
class NamedAxes:
axes: Union[List[NamedAxis], Tuple[NamedAxis]]
name: Optional[str] = 'NamedAxes'
umap: Dict[str, NamedAxis] = field(default_factory=dict, init=False, repr=False)
def __post_init__(self):
# assign unique id to each axis
for i, axis in enumerate(self.axes):
axis.axis = i
self.umap = {ax.axis: ax for ax in self.axes}
@property
def ndim(self):
return len(self.axes)
@property
def anames(self):
# names in current location
return [str(ax.name) for ax in self.axes]
@property
def aidxs(self):
# original location as ax.axis should never be changed
return [int(ax.axis) for ax in self.axes]
@property
def alocs(self):
# current location
return list(range(len(self)))
def __getitem__(self, key:Union[int, str, NamedAxis]) -> NamedAxis:
# NOTE: this gets current location of axis, not original location
if isinstance(key, int):
return self.axes[key]
# NOTE: this gets location based off original location
elif isinstance(key, NamedAxis):
return self.umap[key.axis]
# NOTE: this gets location based off original location
elif isinstance(key, str):
for ax in self.umap.values():
if key == ax.name:
return ax
elif key == str(ax.axis):
return ax
else:
raise KeyError(f'Key {key} not found in {self.name}')
def __str__(self):
_str = f'{self.name}(' + ', '.join(self.anames) + ')'
return _str
__repr__ = __str__
def __iter__(self):
return iter(self.axes)
def __len__(self):
return len(self.axes)
def copy(self):
copy = deepcopy(self)
copy.umap = self.umap.copy()
return copy
def index(self, key:Union[int, str, NamedAxis]):
ax = self[key]
return self.axes.index(ax)
def transpose(self, *order:Union[str, int, NamedAxis]):
# check input and convert to axes
update_axes = [self[key] for key in order]
# gather the axes that are not in the provided order
needed_axes = [ax for ax in self.axes if ax not in update_axes]
# the new order of axes is the updated axes followed by the needed axes
new_order = update_axes + needed_axes
print('NamedAxes.transpose:\t', self.name, self.axes, new_order)
# rearrange axes according to the new order
self.axes = new_order
return self
a, b, c = NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')
abc = NamedAxes((a, b, c))
abc
class NamedArray(np.ndarray):
DIMS = NamedAxes([NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')], name='Trajectories')
def __new__(cls, arr, dims=None):
obj = np.asarray(arr).view(cls)
obj.dims = (dims or cls.DIMS).copy()
return obj
def __new__(cls, arr, dims:NamedAxes=None):
# Input array is an already formed ndarray instance
# We first cast to be our class type
obj = np.asarray(arr).view(cls)
# add the new attribute to the created instance
obj.dims = (dims or cls.DIMS).copy()
# Finally, we must return the newly created object:
return obj
def __array_finalize__(self, obj):
print('finalize, dims=', getattr(obj, 'dims', None))
print('finalize, obj=', obj)
if obj is None: return
self.dims = getattr(obj, 'dims', self.DIMS.copy())
# Ensure the indices are in the correct range
shape = self.shape
if len(shape) != len(self.dims):
raise ValueError('NamedArray must have {len(self.dims)} dimensions, but got {len(shape)}.')
def __array_wrap__(self, out, dims=None):
print('In __array_wrap__:')
print(' self is %s' % repr(self))
print(' arr is %s' % repr(out))
# then just call the parent
return super().__array_wrap__(self, out, dims)
@property
def dim_names(self):
return tuple(self.dims.anames)
@property
def dim_str(self):
_str = ', '.join([f'{s} {n}' for s, n in zip(self.shape, self.dim_names)])
return f'({_str})'
def __repr__(self):
base = super(NamedArray, self).__repr__()
first_line = base.split('\n')[0]
spaces = 0
for s in first_line:
if s.isdigit():
break
spaces += 1
spaces = ' ' * (spaces - 1)
return f'{base}\n{spaces}{self.dim_str}'
def flip(self, axes:Union[str, int, NamedAxis]=None):
# I tried transpose as well
print(self.dims.axes)
# Get the order of axes indices
new_idxs = [self.dims.index(self.dims[ax]) for ax in axes]
print(axes, new_idxs)
# Transpose the NamedAxes
self.dims.transpose(*axes)
print(new_idxs, self.__array_interface__['shape'])
# Transpose the underlying numpy array
self = np.transpose(self, axes=new_idxs)
# self.transpose(*new_idxs)
'''
# NOTE: StackOverflow post edit / clarification
I've tried this a few different ways including
`self.transpose()` as well as just `return np.transpose()`,
and trying to change the function flip to `transpose` etc.
This is just the version I am posting for brevity without
the 10 different `flip` implementations
'''
return self
So lets make some dummy data:
arr = np.random.randint(0, 5, (2, 3, 4))
nar = NamedArray(arr)
nar
# (2 axis-a, 3 axis-b, 4 axis-c)
''' NOTE: flip is basically transpose, with the difference that
`arr.transpose(1, 0, 2).transpose(1, 0, 2)` will do two transposes
but since we are using names and named indices, `nar.flip('b', 'a', 'c').flip('b', 'a', 'c')` should only do one. In other words `flip` is declarative, saying how we want the axes to be. Similar to einops / xarray
'''
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (4 axis-c, 3 axis-b, 2 axis-a)
So far so good. However, when I run the cell again
# (2 axis-a, 3 axis-b, 4 axis-c)
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (2 axis-c, 3 axis-b, 4 axis-a)
I have spent way too long debugging this and I can't figure it out.
source https://stackoverflow.com/questions/76440203/subclassing-ndarray-following-tutorial-yields-unexpected-results-i-e-partial
Comments
Post a Comment