矩阵相乘与矩阵快速幂

网易实习生笔试题,神奇手环,最开始的方法只过了60%,后来想到可以用矩阵的方法来计算。

一、矩阵相乘

下面的矩阵统一用二维数组来表示,先实现两个矩阵相乘的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
public class Matrix {
/*
* 矩阵a与b相乘
* a的列数与b的行数相同才有意义
*/
public static int[][] multi(int[][] a, int[][] b){
if(a.length==0||a[0].length==0||b.length==0||b[0].length==0||a[0].length!=b.length){
return null;
}
int m=a.length, n=b[0].length;
int[][] re=new int[m][n];
for(int i=0; i<m; i++){
for(int j=0; j<n; j++){
int sum=0;
for(int x=0; x<b.length; x++){
sum+=a[i][x]*b[x][j];
}
re[i][j]=sum;
}
}
return re;
}

public static void printMatrix(int[][] matrix){
for(int[] nums: matrix){
System.out.println(Arrays.toString(nums));
}
}

public static void main(String[] args) {
int[][] m1=new int[][]{
{1,2},
{3,4},
{5,6}
};
int[][] m2=new int[][]{
{1,2,3},
{4,5,6}
};
printMatrix(multi(m1, m2));
}
}

输出

1
2
3
[9, 12, 15]
[19, 26, 33]
[29, 40, 51]

时间复杂度是O(n^3)

二、矩阵快速幂

矩阵快速幂可以高效地计算矩阵,将O(n)的时间复杂度降到O(logn)。

它的原理与数的快速幂是一样的,先来回忆一下数的快速幂

1、数的快速幂

  1. Pow(x, n),leetcode原题
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public class Solution {
public double myPow(double x, int n) {
if(n<0){
if(n==Integer.MIN_VALUE){
return 1.0/myPow(x,Integer.MAX_VALUE)*x;
}else{
return 1.0/myPow(x, -n);
}
}
if(n==0){
return 1.0;
}
double re=1.0;
for(; n>0; x*=x, n>>=1){
if((n&1)==1){
re*=x;
}
}
return re;
}
}

它的原理就是x^n=x^(n1+n2+n3+…..)=x^n1*xn2*x^n3…,

其中n=n1+n2+n3… 将n用二进制表示,则n1, n2, n3就可以表示为2^0, 2^1, 2^2….

举个例子,x^5=x^4*x^1

这样时间复杂度就从O(n)降到O(logn)。

2、矩阵快速幂

矩阵快速幂的原理与上面是一样的,在编程之美2.9中也介绍到了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
public class Matrix {
/*
* 矩阵a与b相乘
* a的列数与b的行数相同才有意义
*/
public static int[][] multi(int[][] a, int[][] b){
if(a.length==0||a[0].length==0||b.length==0||b[0].length==0||a[0].length!=b.length){
return null;
}
int m=a.length, n=b[0].length;
int[][] re=new int[m][n];
for(int i=0; i<m; i++){
for(int j=0; j<n; j++){
int sum=0;
for(int x=0; x<b.length; x++){
sum+=a[i][x]*b[x][j];
}
re[i][j]=sum;
}
}
return re;
}

public static void printMatrix(int[][] matrix){
for(int[] nums: matrix){
System.out.println(Arrays.toString(nums));
}
}


/*
* 计算矩阵的幂
* 方阵a的n次幂
*/
public static int[][] matrixPow(int[][] a, int n){
if(a.length==0||a[0].length==0||a.length!=a[0].length){
return null;
}
int len=a.length;
int[][] re=new int[len][len];
for(int i=0; i<len; i++){//单位阵
re[i][i]=1;
}
/*
* 这里需要注意一下,Java中二维数组是无法用clone实现深拷贝的
* 也就是说这个temp指向的内存和a是相同的
* 因为multi方法中每次都会返回一个新的数组,所以说a的内容不会被修改
*
*/
int[][] temp=a.clone();
while(n>0){
if((n&1)==1){
re=multi(re, temp);
}
temp=multi(temp, temp);
n>>=1;
}
return re;
}


public static void main(String[] args) {
int[][] matrix=new int[][]{
{1,2,3},
{4,5,6},
{7,8,9}
};
printMatrix(matrixPow(matrix, 4));
}
}

这里又两点需要注意

  1. 在注释中说到的clone方法无法实现二维数组的深拷贝,由于本例的特殊情况才这样写,正常情况下的二维数组拷贝应该遍历
  2. 当把上面的参数从4改成8后,输出结果就已经出现负数了,说明已经溢出。而且实际上,即使n等于4的时候,temp中就已经出现负数了。所以可见矩阵乘法的增长速度非常大。在实际情况中,应该使用字符串表示才是最可靠的。

三、矩阵快速幂的应用

1、斐波那契数列

最早接触到矩阵快速幂的时候应该就是编程之美2.9,斐波那契数列的问题。

不会编辑公式,具体问题和与原理就不描述了,直接上代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public class MatrixApp {	
/*
* 计算第n个斐波那契数列
*/
public static int fib(int n){
int[][] A=new int[][]{
{1,1},
{1,0}
};
int[][] An=Matrix.matrixPow(A, n-1);

int F1=1, F0=0;

return F1*An[0][0]+F0*An[0][1];
}


public static void main(String[] args) {
System.out.println(fib(10));
}
}

果然,当输入的n太大后,计算的就不准了,因为超过了int范围。

2、魔幻手环

终于到了这个网易的笔试题,魔幻手环

小易拥有一个拥有魔力的手环上面有n个数字(构成一个环),当这个魔力手环每次使用魔力的时候就会发生一种奇特的变化:每个数字会变成自己跟后面一个数字的和(最后一个数字的后面一个数字是第一个),一旦某个位置的数字大于等于100就马上对100取模(比如某个位置变为103,就会自动变为3).现在给出这个魔力手环的构成,请你计算出使用k次魔力之后魔力手环的状态。

输入描述:

输入数据包括两行:

第一行为两个整数n(2 ≤ n ≤ 50)和k(1 ≤ k ≤ 2000000000),以空格分隔

第二行为魔力手环初始的n个数,以空格分隔。范围都在0至99.

输出描述:

输出魔力手环使用k次之后的状态,以空格分隔,行末无空格。

输入例子:

3 2

1 2 3

输出例子:

8 9 7

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
public class Main {	
/*
* 实习笔试 排队形 男生女生
*/
public static void swap(String s){
char[] chars=s.toCharArray();
int re1=0, re2=0, g=0, b=0;
for(int i=0; i<chars.length; i++){
if(chars[i]=='G'){
re1+=i-g;
g++;
}else{
re2+=i-b;
b++;
}
}
System.out.print(Math.min(re1, re2));
}

/*
* 魔力手环
*/
public static int[][] multi(int[][] a, int[][] b){
if(a.length==0||a[0].length==0||b.length==0||b[0].length==0||a[0].length!=b.length){
return null;
}
int m=a.length, n=b[0].length;
int[][] re=new int[m][n];
for(int i=0; i<m; i++){
for(int j=0; j<n; j++){
int sum=0;
for(int x=0; x<b.length; x++){
sum+=a[i][x]*b[x][j]%100;
}
re[i][j]=sum%100;
}
}
return re;
}

public static void print(int[][] matrix){
int[] nums=matrix[0];
System.out.print(nums[0]);
for(int i=1; i<nums.length; i++){
System.out.print(" "+nums[i]);
}
}

public static int[][] matrixPow(int[][] a, int n){
if(a.length==0||a[0].length==0||a.length!=a[0].length){
return null;
}
int len=a.length;
int[][] re=new int[len][len];
for(int i=0; i<len; i++){//单位阵
re[i][i]=1;
}
int[][] temp=a.clone();

while(n>0){
if((n&1)==1){
re=multi(re, temp);
}
temp=multi(temp, temp);
n>>=1;
}
return re;
}

private static void magicHoop(int[] nums, int n, int k){
int[][] a=new int[n][n];
a[0][0]=1;
a[0][n-1]=1;
for(int i=1; i<n; i++){
a[i][i]=1;
a[i][i-1]=1;
}
print(multi(new int[][]{nums}, matrixPow(a, k)));
}

public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n=sc.nextInt();
int k=sc.nextInt();
int[] nums=new int[n];
for(int i=0; i<n; i++){
nums[i]=sc.nextInt();
}
sc.close();
magicHoop(nums, n, k);
}
}

题中特意加了超过100后取模,看来就是为了防止溢出。