Numpy Matrix class: Default constructor arguments

Posted 2020-04-05 09:16

I want to implement my own matrix-class that inherits from numpy's matrix class.

numpy's matrix constructor requires an argument, something like "1 2; 3 4". In contrast, my constructor should take no arguments and should pass a default argument to the superclass constructor.

That's what I did:

import numpy as np

class MyMatrix(np.matrix):
    def __init__(self):
        super(MyMatrix, self).__init__("1 2; 3 4")

if __name__ == "__main__":
    matrix = MyMatrix()

There must be a stupid mistake in this code since I keep getting this error:

this_matrix = np.matrix()
TypeError: __new__() takes at least 2 arguments (1 given)

I'm really clueless about this, and googling hasn't helped so far.

Thanks!

1 answer
一夜七次
#2 · 2020-04-05 09:32

Good question!

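From looking at the source, np.matrix takes its data argument in __new__, not in __init__. Python calls __new__ before __init__, so MyMatrix() ends up calling np.matrix.__new__ with no data, and the TypeError is raised before your __init__ ever runs. This is counterintuitive, but it is how ndarray subclasses generally work: the array's data is set up when the object is created in __new__.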
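To make the __new__-before-__init__ ordering concrete, here is a minimal pure-Python sketch with no NumPy involved. The classes `Base`, `InitOnly` and `NewOverride` are hypothetical: `Base` only stands in for np.matrix by requiring a `data` argument in its __new__, which is a simplification of the real class.

```python
# Pure-Python sketch (no NumPy) of why overriding only __init__ is not enough.
# `Base` is a hypothetical stand-in for np.matrix: its __new__ requires `data`.
class Base:
    def __new__(cls, data):
        obj = super().__new__(cls)
        obj.data = data
        return obj


class InitOnly(Base):
    # Python calls Base.__new__(InitOnly) *before* this __init__ runs,
    # so the missing `data` argument raises TypeError and __init__ never executes.
    def __init__(self):
        super().__init__()


class NewOverride(Base):
    # Supplying the default in __new__ satisfies Base.__new__'s signature.
    def __new__(cls):
        return super().__new__(cls, "1 2; 3 4")


try:
    InitOnly()
except TypeError as exc:
    print("InitOnly failed:", exc)

print(NewOverride().data)  # prints: 1 2; 3 4
```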

Anyway, the following works for me:

import numpy as np

class MyMatrix(np.matrix):
    def __new__(cls):
        # We have to pass cls to the parent's __new__ explicitly, even though super() already received it:
        # __new__ is implicitly a static method, so super() does not bind cls to it.
        return super(MyMatrix, cls).__new__(cls, "1 2; 3 4")

mat = MyMatrix()

print(mat)
# outputs:
# [[1 2]
#  [3 4]]
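If you want a default but still accept other inputs, one option is to give __new__ a default parameter and forward everything else to np.matrix. This sketch assumes the signature np.matrix.__new__(subtype, data, dtype=None, copy=True); the default string is just the example value from the question.

```python
import numpy as np


class MyMatrix(np.matrix):
    # Default data, while still letting callers pass their own data, dtype and
    # copy flag through to np.matrix.__new__ (assumed signature, see above).
    def __new__(cls, data="1 2; 3 4", dtype=None, copy=True):
        return super().__new__(cls, data, dtype=dtype, copy=copy)


default = MyMatrix()           # uses the default "1 2; 3 4"
custom = MyMatrix([[5, 6]])    # any input np.matrix accepts should still work
print(default.tolist())        # [[1, 2], [3, 4]]
print(custom.shape)            # (1, 2)
print(isinstance(default, np.matrix))  # True
```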

Addendum: you might want to consider using a factory function, rather than a subclass, for the behaviour you want. This would give you the following code, which is much shorter and clearer, and doesn't depend on the __new__-vs-__init__ implementation detail:

import numpy as np

def mymatrix():
    return np.matrix('1 2; 3 4')

mat = mymatrix()
print(mat)
# outputs:
# [[1 2]
#  [3 4]]

Of course, you might need a subclass for other reasons.
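To illustrate the trade-off, here is a sketch that builds the same values both ways. The factory returns a plain np.matrix, while the subclass gives you your own type that you can check with isinstance or extend with methods; `describe` is a hypothetical example method, not part of NumPy.

```python
import numpy as np


def mymatrix():
    # Factory: returns an ordinary np.matrix, no subclass machinery involved.
    return np.matrix("1 2; 3 4")


class MyMatrix(np.matrix):
    # Subclass: useful when you need your own type, e.g. for isinstance checks
    # or extra methods. `describe` is a hypothetical example method.
    def __new__(cls):
        return super().__new__(cls, "1 2; 3 4")

    def describe(self):
        return "MyMatrix with shape {}".format(self.shape)


a = mymatrix()
b = MyMatrix()
print(type(a) is np.matrix)    # True
print(isinstance(b, MyMatrix)) # True
print(b.describe())            # MyMatrix with shape (2, 2)
print(np.array_equal(a, b))    # True: same values either way
```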
