Skip to content

Commit a83a519

Browse files
committed
[ruby/matrix] Optimize **
Avoiding recursive call would imply iterating bits starting from most significant, which is not easy to do efficiently. Any saving would be dwarfed by the multiplications anyways. [Feature #15233]
1 parent 3b5b309 commit a83a519

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

lib/matrix.rb

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,26 +1233,49 @@ def inverse
12331233
# # => 67 96
12341234
# # 48 99
12351235
#
1236-
def **(other)
1237-
case other
1236+
def **(exp)
1237+
case exp
12381238
when Integer
1239-
x = self
1240-
if other <= 0
1241-
x = self.inverse
1242-
return self.class.identity(self.column_count) if other == 0
1243-
other = -other
1244-
end
1245-
z = nil
1246-
loop do
1247-
z = z ? z * x : x if other[0] == 1
1248-
return z if (other >>= 1).zero?
1249-
x *= x
1239+
case
1240+
when exp == 0
1241+
_make_sure_it_is_invertible = inverse
1242+
self.class.identity(column_count)
1243+
when exp < 0
1244+
inverse.power_int(-exp)
1245+
else
1246+
power_int(exp)
12501247
end
12511248
when Numeric
12521249
v, d, v_inv = eigensystem
1253-
v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** other}) * v_inv
1250+
v * self.class.diagonal(*d.each(:diagonal).map{|e| e ** exp}) * v_inv
1251+
else
1252+
raise ErrOperationNotDefined, ["**", self.class, exp.class]
1253+
end
1254+
end
1255+
1256+
protected def power_int(exp)
1257+
# assumes `exp` is an Integer > 0
1258+
#
1259+
# Previous algorithm:
1260+
# build M**2, M**4 = (M**2)**2, M**8, ... and multiplying those you need
1261+
# e.g. M**0b1011 = M**11 = M * M**2 * M**8
1262+
# ^ ^
1263+
# (highlighted the 2 out of 5 multiplications involving `M * x`)
1264+
#
1265+
# Current algorithm has same number of multiplications but with lower exponents:
1266+
# M**11 = M * (M * M**4)**2
1267+
# ^ ^ ^
1268+
# (highlighted the 3 out of 5 multiplications involving `M * x`)
1269+
#
1270+
# This should be faster for all (non nil-potent) matrices.
1271+
case
1272+
when exp == 1
1273+
self
1274+
when exp.odd?
1275+
self * power_int(exp - 1)
12541276
else
1255-
raise ErrOperationNotDefined, ["**", self.class, other.class]
1277+
sqrt = power_int(exp / 2)
1278+
sqrt * sqrt
12561279
end
12571280
end
12581281

test/matrix/test_matrix.rb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ def test_exp
448448
assert_equal(Matrix[[67,96],[48,99]], Matrix[[7,6],[3,9]] ** 2)
449449
assert_equal(Matrix.I(5), Matrix.I(5) ** -1)
450450
assert_raise(Matrix::ErrOperationNotDefined) { Matrix.I(5) ** Object.new }
451+
452+
m = Matrix[[0,2],[1,0]]
453+
exp = 0b11101000
454+
assert_equal(Matrix.scalar(2, 1 << (exp/2)), m ** exp)
455+
exp = 0b11101001
456+
assert_equal(Matrix[[0, 2 << (exp/2)], [1 << (exp/2), 0]], m ** exp)
451457
end
452458

453459
def test_det

0 commit comments

Comments
 (0)