Subclassing `ndarray` following the tutorial yields unexpected results (i.e. partial memory: some attributes are remembered, others are lost)

I think I followed the NumPy subclassing tutorial correctly, and I have a very simple example. It works when I run the code once, but when I rerun a cell in a Jupyter notebook the class breaks and "forgets" state: it remembers the attribute I added (the axis names), but it forgets the transpose I applied to the underlying NumPy array. See the code below.

Below I implement three simple classes: `NamedAxis`, `NamedAxes`, and `NamedArray` (yes, I am aware of xarray; this is for my own learning). Mostly it works fine. However, I notice something very frustrating when I rerun `flip`:

from copy import deepcopy
from dataclasses import dataclass, field
from typing import List, Dict, Union, Optional, Any, Callable, TypeVar, Generic, Type, cast, Tuple
import numpy as np, pandas as pd


@dataclass
class NamedAxis:
    # name of axis
    name: str
    # index of axis
    axis: Optional[int] = None

    def __str__(self):
        return f'{self.name}({self.axis})'
    
    __repr__ = __str__

    def copy(self) -> 'NamedAxis':
        copy = deepcopy(self)
        return copy

    
@dataclass    
class NamedAxes:
    axes: Union[List[NamedAxis], Tuple[NamedAxis, ...]]
    name: Optional[str] = 'NamedAxes'
    # keyed by each axis' original index (an int), see __post_init__
    umap: Dict[int, NamedAxis] = field(default_factory=dict, init=False, repr=False)
    
    def __post_init__(self):
        # assign unique id to each axis
        for i, axis in enumerate(self.axes):
            axis.axis = i
        
        self.umap = {ax.axis: ax for ax in self.axes}

    @property
    def ndim(self):
        return len(self.axes)
    
    @property
    def anames(self):
        # names in current location
        return [str(ax.name) for ax in self.axes]
    
    @property
    def aidxs(self):
        # original location as ax.axis should never be changed
        return [int(ax.axis) for ax in self.axes]
    
    @property
    def alocs(self):
        # current location
        return list(range(len(self)))

    def __getitem__(self, key:Union[int, str, NamedAxis]) -> NamedAxis:
        # NOTE: this gets current location of axis, not original location
        if isinstance(key, int):
            return self.axes[key]
        
        # NOTE: this gets location based off original location
        elif isinstance(key, NamedAxis):
            return self.umap[key.axis]

        # NOTE: this gets location based off original location
        elif isinstance(key, str):
            for ax in self.umap.values():
                if key == ax.name:
                    return ax
                
                elif key == str(ax.axis):
                    return ax

        # unsupported key type, or a string matching no axis name / original index
        # (previously an unknown string silently returned None)
        raise KeyError(f'Key {key} not found in {self.name}')
        
    def __str__(self):
        _str = f'{self.name}(' + ', '.join(self.anames) + ')'
        return _str
    
    __repr__ = __str__
    
    def __iter__(self):
        return iter(self.axes)

    def __len__(self):
        return len(self.axes)
    
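    def copy(self):
        # deepcopy already copies umap so that its values are the *copied* axes;
        # re-assigning self.umap.copy() would point back at the original axes
        return deepcopy(self)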
    
    def index(self, key:Union[int, str, NamedAxis]):
        ax = self[key]
        return self.axes.index(ax)

    def transpose(self, *order:Union[str, int, NamedAxis]):
        # check input and convert to axes
        update_axes = [self[key] for key in order]

        # gather the axes that are not in the provided order
        needed_axes = [ax for ax in self.axes if ax not in update_axes]
        
        # the new order of axes is the updated axes followed by the needed axes
        new_order = update_axes + needed_axes
        print('NamedAxes.transpose:\t', self.name, self.axes, new_order)

        # rearrange axes according to the new order
        self.axes = new_order
        return self


a, b, c = NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')
abc = NamedAxes((a, b, c))
abc




class NamedArray(np.ndarray):
    DIMS = NamedAxes([NamedAxis('axis-a'), NamedAxis('axis-b'), NamedAxis('axis-c')], name='Trajectories')
    
    def __new__(cls, arr, dims:Optional[NamedAxes]=None):
        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = np.asarray(arr).view(cls)
        # add the new attribute to the created instance
        obj.dims = (dims or cls.DIMS).copy()
        # Finally, we must return the newly created object:
        return obj
            
    def __array_finalize__(self, obj):
        print('finalize, dims=', getattr(obj, 'dims', None))
        print('finalize, obj=', obj)
        if obj is None: return        
        self.dims = getattr(obj, 'dims', self.DIMS.copy())

        # Ensure the indices are in the correct range
        shape = self.shape
        if len(shape) != len(self.dims):
            raise ValueError(f'NamedArray must have {len(self.dims)} dimensions, but got {len(shape)}.')
        
    def __array_wrap__(self, out, context=None):
        print('In __array_wrap__:')
        print('   self is %s' % repr(self))
        print('   arr is %s' % repr(out))
        # then just call the parent (super() is already bound to self,
        # so self must not be passed again as an argument)
        return super().__array_wrap__(out, context)
    
    
    @property
    def dim_names(self):
        return tuple(self.dims.anames)
            
    @property
    def dim_str(self):
        _str = ', '.join([f'{s} {n}' for s, n in zip(self.shape, self.dim_names)])
        return f'({_str})'
             
    def __repr__(self):
        base = super(NamedArray, self).__repr__()        
        first_line = base.split('\n')[0]
        spaces = 0
        for s in first_line:            
            if s.isdigit():
                break
            spaces += 1
        spaces = ' ' * (spaces - 1)
        return f'{base}\n{spaces}{self.dim_str}'
    

    
    def flip(self, axes:Tuple[Union[str, int, NamedAxis], ...]):
        # I tried transpose as well
        print(self.dims.axes)    
        # Get the order of axes indices        
        new_idxs = [self.dims.index(self.dims[ax]) for ax in axes]
        print(axes, new_idxs)

        # Transpose the NamedAxes
        self.dims.transpose(*axes)        
        print(new_idxs, self.__array_interface__['shape'])
        
        # Transpose the underlying numpy array
        self = np.transpose(self, axes=new_idxs)
        # self.transpose(*new_idxs)
        

        '''
        # NOTE: StackOverflow post edit / clarification
        I've tried this a few different ways including 
        `self.transpose()` as well as just `return np.transpose()`, 
        and trying to change the function flip to `transpose` etc. 
        This is just the version I am posting for brevity without 
        the 10 different `flip` implementations
        '''


        return self    
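The `__array_finalize__` above sets `self.dims = getattr(obj, 'dims', ...)`. That hands a new view the parent's `NamedAxes` object itself, not a copy. A minimal sketch of that NumPy behaviour follows, assuming only NumPy. `Tagged` and `tags` are hypothetical stand-ins for `NamedArray` and `dims`, and a plain list replaces `NamedAxes`.

```python
import numpy as np


class Tagged(np.ndarray):
    """Hypothetical minimal subclass mirroring NamedArray's attribute handling."""

    def __new__(cls, arr, tags=None):
        obj = np.asarray(arr).view(cls)
        # each freshly constructed array gets its own list
        obj.tags = list(tags) if tags is not None else []
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # same pattern as NamedArray: copy the *reference*, not the object,
        # so every view created from `obj` shares one mutable list
        self.tags = getattr(obj, 'tags', [])


base = Tagged(np.zeros((2, 3)), tags=['rows', 'cols'])
view = base.T                   # .T builds a view and calls __array_finalize__
print(view.tags is base.tags)   # True: both arrays hold the same list
view.tags.reverse()             # mutate through the view...
print(base.shape, base.tags)    # (2, 3) ['cols', 'rows']: ...and the original sees it
```

Views such as `.T`, `np.transpose(...)` or slicing never call `__new__`, so `__array_finalize__` is the only place the attribute gets set for them.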

So let's make some dummy data:

arr = np.random.randint(0, 5, (2, 3, 4))
nar = NamedArray(arr)
nar
# (2 axis-a, 3 axis-b, 4 axis-c)

''' NOTE: flip is basically transpose, with one difference:
`arr.transpose(1, 0, 2).transpose(1, 0, 2)` performs two transposes (ending up
back at the original order), but since we are using named axes,
`nar.flip(('axis-b', 'axis-a', 'axis-c')).flip(('axis-b', 'axis-a', 'axis-c'))`
should have the same effect as a single flip. In other words, `flip` is
declarative: it states the order we want the axes to end up in, similar to
einops / xarray.
'''

nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (4 axis-c, 3 axis-b, 2 axis-a)
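The declarative behaviour described in the note can be sketched with plain NumPy arrays and a tuple of axis names. `flip_by_names` is a hypothetical helper for illustration, not part of NumPy or of the classes above. It looks up each wanted axis in the *current* names, so applying the same order twice is a no-op the second time, while a positional `transpose` is relative and undoes itself.

```python
import numpy as np


def flip_by_names(arr, names, order):
    """Hypothetical helper: reorder the axes of `arr` so their names follow `order`.

    `names` describes the current axes of `arr`; the result is a transposed
    view together with its own new tuple of names (the input is not mutated).
    """
    perm = [names.index(name) for name in order]  # current position of each wanted axis
    return np.transpose(arr, perm), tuple(order)


arr = np.zeros((2, 3, 4))
names = ('axis-a', 'axis-b', 'axis-c')
order = ('axis-c', 'axis-b', 'axis-a')

once, names1 = flip_by_names(arr, names, order)
twice, names2 = flip_by_names(once, names1, order)
print(once.shape, twice.shape)  # (4, 3, 2) (4, 3, 2): declarative, so idempotent

# a positional transpose is relative, so applying it twice undoes it
print(arr.transpose(2, 1, 0).transpose(2, 1, 0).shape)  # (2, 3, 4)
```

Returning the names alongside the array, instead of editing them in place, keeps the input's names consistent with the input's shape.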

So far so good. However, when I run the same cell again:

# (2 axis-a, 3 axis-b, 4 axis-c)
nar.flip(('axis-c', 'axis-b', 'axis-a'))
# (2 axis-c, 3 axis-b, 4 axis-a)
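The rerun result appears to follow from three facts together. `flip` reorders the `dims` object in place. `self = np.transpose(...)` only rebinds the local name, so the transposed view is returned while `nar` keeps its original data layout. And the cell never reassigns `nar`, so the second run starts from an array whose names were already reordered but whose shape was not. Below is a stripped-down reproduction under those assumptions. `Named`, `flip_inplace_names` and `flip_new_names` are hypothetical names for this sketch, and `flip_new_names` is only one possible alternative, not the definitive fix.

```python
import numpy as np


class Named(np.ndarray):
    """Hypothetical stripped-down NamedArray holding a mutable list of axis names."""

    def __new__(cls, arr, names):
        obj = np.asarray(arr).view(cls)
        obj.names = list(names)
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        # as in NamedArray: views share the parent's names object
        self.names = getattr(obj, 'names', None)

    def flip_inplace_names(self, order):
        # mirrors the question's flip(): names are reordered in place first...
        perm = [self.names.index(n) for n in order]
        self.names[:] = list(order)
        # ...then only the local name `self` is rebound to a new view;
        # the caller's array keeps its original shape
        self = np.transpose(self, perm)
        return self

    def flip_new_names(self, order):
        # alternative: leave self untouched and give the view its own names
        perm = [self.names.index(n) for n in order]
        out = np.transpose(self, perm)
        out.names = list(order)
        return out


order = ['axis-c', 'axis-b', 'axis-a']

x = Named(np.zeros((2, 3, 4)), ['axis-a', 'axis-b', 'axis-c'])
first = x.flip_inplace_names(order)
print(first.shape, first.names)  # (4, 3, 2) ['axis-c', 'axis-b', 'axis-a']
print(x.shape, x.names)          # (2, 3, 4) ['axis-c', 'axis-b', 'axis-a']: out of sync
second = x.flip_inplace_names(order)
print(second.shape)              # (2, 3, 4): the "rerun" result

y = Named(np.zeros((2, 3, 4)), ['axis-a', 'axis-b', 'axis-c'])
print(y.flip_new_names(order).shape, y.flip_new_names(order).shape)  # (4, 3, 2) (4, 3, 2)
print(y.names)                   # ['axis-a', 'axis-b', 'axis-c']
```

On the second call, every name is already where `order` wants it, so the computed permutation is `[0, 1, 2]` and nothing is transposed. This matches the `(2 axis-c, 3 axis-b, 4 axis-a)` output above.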

I have spent way too long debugging this and I can't figure it out.



source https://stackoverflow.com/questions/76440203/subclassing-ndarray-following-tutorial-yields-unexpected-results-i-e-partial
