Actual source code: fsolvebaij.F
1: !
2: !
3: ! Fortran kernel for sparse triangular solve in the BAIJ matrix format
4: ! This ONLY works for factorizations in the NATURAL ORDERING, i.e.
5: ! with MatSolve_SeqBAIJ_4_NaturalOrdering()
6: !
7: #include include/finclude/petscdef.h
8: !
!
!     FortranSolveBAIJ4Unroll - forward and backward triangular solve
!     for a factored matrix stored as 4x4 blocks, with every 4x4
!     block-times-vector product written out by hand.
!
!     All arrays are indexed from 0.
!
!       n      number of block rows
!       x      solution, 4 entries per block row (written)
!       ai     ai(i) is the index of the first stored block of block
!              row i, ai(i+1)-1 the index of its last stored block
!       aj     block column index of every stored block
!       adiag  adiag(i) is the index of the diagonal block of row i
!       a      block values, 16 per block, stored column by column
!       b      right-hand side, 4 entries per block row (read only)
!
!     The diagonal blocks are multiplied into the result, not divided
!     out, so they are presumably stored already inverted
!     (NOTE(review): confirm against the factorization routine).
!
      subroutine FortranSolveBAIJ4Unroll(n,x,ai,aj,adiag,a,b)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)

      PetscInt    row,k,kfirst,klast,xoff,aoff,coff
      PetscScalar t1,t2,t3,t4
      PetscScalar c1,c2,c3,c4
!
!     With no block rows there is nothing to solve; returning here
!     avoids reading b(0:3) and writing x(0:3) on empty vectors.
!
      if (n .lt. 1) return
!
!     Forward solve: block row 0 has no blocks left of the diagonal,
!     so its part of x is a plain copy of b.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)

      do row = 1, n-1
         xoff = 4*row
         t1 = b(xoff)
         t2 = b(xoff+1)
         t3 = b(xoff+2)
         t4 = b(xoff+3)
!
!        Subtract (block) * (already computed part of x) for every
!        block stored before the diagonal block of this row
!
         kfirst = ai(row)
         klast  = adiag(row) - 1
         aoff   = 16*kfirst
         do k = kfirst, klast
            coff = 4*aj(k)
            c1 = x(coff)
            c2 = x(coff+1)
            c3 = x(coff+2)
            c4 = x(coff+3)
            t1 = t1 - (a(aoff)*c1     + a(aoff+4)*c2
     &              +  a(aoff+8)*c3   + a(aoff+12)*c4)
            t2 = t2 - (a(aoff+1)*c1   + a(aoff+5)*c2
     &              +  a(aoff+9)*c3   + a(aoff+13)*c4)
            t3 = t3 - (a(aoff+2)*c1   + a(aoff+6)*c2
     &              +  a(aoff+10)*c3  + a(aoff+14)*c4)
            t4 = t4 - (a(aoff+3)*c1   + a(aoff+7)*c2
     &              +  a(aoff+11)*c3  + a(aoff+15)*c4)
            aoff = aoff + 16
         end do
         x(xoff)   = t1
         x(xoff+1) = t2
         x(xoff+2) = t3
         x(xoff+3) = t4
      end do
!
!     Backward solve with the upper triangular part, last row first
!
      do row = n-1, 0, -1
         xoff = 4*row
         t1 = x(xoff)
         t2 = x(xoff+1)
         t3 = x(xoff+2)
         t4 = x(xoff+3)
!
!        Blocks stored after the diagonal block of this row
!
         kfirst = adiag(row) + 1
         klast  = ai(row+1) - 1
         aoff   = 16*kfirst
         do k = kfirst, klast
            coff = 4*aj(k)
            c1 = x(coff)
            c2 = x(coff+1)
            c3 = x(coff+2)
            c4 = x(coff+3)
            t1 = t1 - (a(aoff)*c1     + a(aoff+4)*c2
     &              +  a(aoff+8)*c3   + a(aoff+12)*c4)
            t2 = t2 - (a(aoff+1)*c1   + a(aoff+5)*c2
     &              +  a(aoff+9)*c3   + a(aoff+13)*c4)
            t3 = t3 - (a(aoff+2)*c1   + a(aoff+6)*c2
     &              +  a(aoff+10)*c3  + a(aoff+14)*c4)
            t4 = t4 - (a(aoff+3)*c1   + a(aoff+7)*c2
     &              +  a(aoff+11)*c3  + a(aoff+15)*c4)
            aoff = aoff + 16
         end do
!
!        Apply the stored diagonal block to the accumulated values
!
         aoff = 16*adiag(row)
         x(xoff)   = a(aoff)*t1    + a(aoff+4)*t2
     &             + a(aoff+8)*t3  + a(aoff+12)*t4
         x(xoff+1) = a(aoff+1)*t1  + a(aoff+5)*t2
     &             + a(aoff+9)*t3  + a(aoff+13)*t4
         x(xoff+2) = a(aoff+2)*t1  + a(aoff+6)*t2
     &             + a(aoff+10)*t3 + a(aoff+14)*t4
         x(xoff+3) = a(aoff+3)*t1  + a(aoff+7)*t2
     &             + a(aoff+11)*t3 + a(aoff+15)*t4
      end do
      return
      end
88:
89: !
90: ! version that calls BLAS 2 operation for each row block
91: !
!
!     FortranSolveBAIJ4BLAS - forward and backward triangular solve
!     for a factored matrix stored as 4x4 blocks.  The off-diagonal
!     part of every block row is applied with the BLAS 2 routine
!     dgemv after the needed entries of x are packed into w.
!
!     All arrays are indexed from 0.
!
!       n      number of block rows
!       x      solution, 4 entries per block row (written)
!       ai     ai(i) is the index of the first stored block of block
!              row i, ai(i+1)-1 the index of its last stored block
!       aj     block column index of every stored block
!       adiag  adiag(i) is the index of the diagonal block of row i
!       a      block values, 16 per block, stored column by column
!       b      right-hand side, 4 entries per block row (read only)
!       w      work array; at most 4*maxblk entries are used
!
!     A row with more than maxblk off-diagonal blocks is handled in
!     several passes of at most maxblk blocks, so w is never overrun.
!     maxblk = 500 matches the limit the previous implementation
!     warned about (NOTE(review): this assumes the caller provides
!     2000 entries in w - confirm against the caller).
!
!     NOTE(review): dgemv is the double precision routine, so a, w
!     and s must be double precision reals.  sgemv with -1.0 and 1.0
!     would be the single precision counterpart.
!
      subroutine FortranSolveBAIJ4BLAS(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*),w(0:*)
      PetscScalar x(0:*),b(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*)

      PetscInt    row,k,kfirst,klast,klo,khi
      PetscInt    xoff,aoff,coff,woff
      PetscInt    maxblk
      parameter   (maxblk = 500)
      MatScalar   s(0:3)
!
!     With no block rows there is nothing to solve; returning here
!     avoids reading b(0:3) and writing x(0:3) on empty vectors.
!
      if (n .lt. 1) return
!
!     Forward solve: block row 0 has no blocks left of the diagonal,
!     so its part of x is a plain copy of b.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)

      do row = 1, n-1
         xoff = 4*row
         s(0) = b(xoff)
         s(1) = b(xoff+1)
         s(2) = b(xoff+2)
         s(3) = b(xoff+3)

         kfirst = ai(row)
         klast  = adiag(row) - 1
         do klo = kfirst, klast, maxblk
            khi = min(klo+maxblk-1, klast)
!
!           Pack the entries of x needed by blocks klo..khi
!
            woff = 0
            do k = klo, khi
               coff = 4*aj(k)
               w(woff)   = x(coff)
               w(woff+1) = x(coff+1)
               w(woff+2) = x(coff+2)
               w(woff+3) = x(coff+3)
               woff = woff + 4
            end do
!
!           s = s - a(16*klo:)*w ; woff is now the number of columns
!
            call dgemv('n',4,woff,-1.d0,a(16*klo),4,w,1,1.d0,s,1)
         end do

         x(xoff)   = s(0)
         x(xoff+1) = s(1)
         x(xoff+2) = s(2)
         x(xoff+3) = s(3)
      end do
!
!     Backward solve with the upper triangular part, last row first
!
      do row = n-1, 0, -1
         xoff = 4*row
         s(0) = x(xoff)
         s(1) = x(xoff+1)
         s(2) = x(xoff+2)
         s(3) = x(xoff+3)

         kfirst = adiag(row) + 1
         klast  = ai(row+1) - 1
         do klo = kfirst, klast, maxblk
            khi = min(klo+maxblk-1, klast)
!
!           Pack the entries of x needed by blocks klo..khi
!
            woff = 0
            do k = klo, khi
               coff = 4*aj(k)
               w(woff)   = x(coff)
               w(woff+1) = x(coff+1)
               w(woff+2) = x(coff+2)
               w(woff+3) = x(coff+3)
               woff = woff + 4
            end do
            call dgemv('n',4,woff,-1.d0,a(16*klo),4,w,1,1.d0,s,1)
         end do
!
!        Apply the stored diagonal block to the accumulated values
!
         aoff = 16*adiag(row)
         x(xoff)   = a(aoff)*s(0)    + a(aoff+4)*s(1)
     &             + a(aoff+8)*s(2)  + a(aoff+12)*s(3)
         x(xoff+1) = a(aoff+1)*s(0)  + a(aoff+5)*s(1)
     &             + a(aoff+9)*s(2)  + a(aoff+13)*s(3)
         x(xoff+2) = a(aoff+2)*s(0)  + a(aoff+6)*s(1)
     &             + a(aoff+10)*s(2) + a(aoff+14)*s(3)
         x(xoff+3) = a(aoff+3)*s(0)  + a(aoff+7)*s(1)
     &             + a(aoff+11)*s(2) + a(aoff+15)*s(3)
      end do
      return
      end
186:
188: !
189: ! version that does not call BLAS 2 operation for each row block
190: !
!
!     FortranSolveBAIJ4 - forward and backward triangular solve for a
!     factored matrix stored as 4x4 blocks.  The entries of x needed
!     by a block row are packed into w and the 4-by-(4*nblocks)
!     product is then formed with plain loops (no BLAS call).
!
!     All arrays are indexed from 0.
!
!       n      number of block rows
!       x      solution, 4 entries per block row (written)
!       ai     ai(i) is the index of the first stored block of block
!              row i, ai(i+1)-1 the index of its last stored block
!       aj     block column index of every stored block
!       adiag  adiag(i) is the index of the diagonal block of row i
!       a      block values, 16 per block, stored column by column
!       b      right-hand side, 4 entries per block row (read only)
!       w      work array; at most 4*maxblk entries are used
!
!     A row with more than maxblk off-diagonal blocks is handled in
!     several passes of at most maxblk blocks, so w is never overrun.
!     maxblk = 500 matches the limit the previous implementation
!     warned about (NOTE(review): this assumes the caller provides
!     2000 entries in w - confirm against the caller).
!
      subroutine FortranSolveBAIJ4(n,x,ai,aj,adiag,a,b,w)
      implicit none
      MatScalar   a(0:*)
      PetscScalar x(0:*),b(0:*),w(0:*)
      PetscInt    n,ai(0:*),aj(0:*),adiag(0:*),ii,jj

      PetscInt    row,k,kfirst,klast,klo,khi
      PetscInt    xoff,aoff,coff,woff
      PetscInt    maxblk
      parameter   (maxblk = 500)
      PetscScalar s(0:3)
!
!     With no block rows there is nothing to solve; returning here
!     avoids reading b(0:3) and writing x(0:3) on empty vectors.
!
      if (n .lt. 1) return
!
!     Forward solve: block row 0 has no blocks left of the diagonal,
!     so its part of x is a plain copy of b.
!
      x(0) = b(0)
      x(1) = b(1)
      x(2) = b(2)
      x(3) = b(3)

      do row = 1, n-1
         xoff = 4*row
         s(0) = b(xoff)
         s(1) = b(xoff+1)
         s(2) = b(xoff+2)
         s(3) = b(xoff+3)

         kfirst = ai(row)
         klast  = adiag(row) - 1
         do klo = kfirst, klast, maxblk
            khi = min(klo+maxblk-1, klast)
!
!           Pack the entries of x needed by blocks klo..khi
!
            woff = 0
            do k = klo, khi
               coff = 4*aj(k)
               w(woff)   = x(coff)
               w(woff+1) = x(coff+1)
               w(woff+2) = x(coff+2)
               w(woff+3) = x(coff+3)
               woff = woff + 4
            end do
!
!           s = s - a(aoff:)*w.  The row index ii is innermost so a
!           is read with stride 1; each s(ii) still accumulates its
!           terms in increasing jj order.
!
            aoff = 16*klo
            do jj = 0, woff-1
               do ii = 0, 3
                  s(ii) = s(ii) - a(aoff+4*jj+ii)*w(jj)
               end do
            end do
         end do

         x(xoff)   = s(0)
         x(xoff+1) = s(1)
         x(xoff+2) = s(2)
         x(xoff+3) = s(3)
      end do
!
!     Backward solve with the upper triangular part, last row first
!
      do row = n-1, 0, -1
         xoff = 4*row
         s(0) = x(xoff)
         s(1) = x(xoff+1)
         s(2) = x(xoff+2)
         s(3) = x(xoff+3)

         kfirst = adiag(row) + 1
         klast  = ai(row+1) - 1
         do klo = kfirst, klast, maxblk
            khi = min(klo+maxblk-1, klast)
!
!           Pack the entries of x needed by blocks klo..khi
!
            woff = 0
            do k = klo, khi
               coff = 4*aj(k)
               w(woff)   = x(coff)
               w(woff+1) = x(coff+1)
               w(woff+2) = x(coff+2)
               w(woff+3) = x(coff+3)
               woff = woff + 4
            end do
            aoff = 16*klo
            do jj = 0, woff-1
               do ii = 0, 3
                  s(ii) = s(ii) - a(aoff+4*jj+ii)*w(jj)
               end do
            end do
         end do
!
!        Apply the stored diagonal block to the accumulated values
!
         aoff = 16*adiag(row)
         x(xoff)   = a(aoff)*s(0)    + a(aoff+4)*s(1)
     &             + a(aoff+8)*s(2)  + a(aoff+12)*s(3)
         x(xoff+1) = a(aoff+1)*s(0)  + a(aoff+5)*s(1)
     &             + a(aoff+9)*s(2)  + a(aoff+13)*s(3)
         x(xoff+2) = a(aoff+2)*s(0)  + a(aoff+6)*s(1)
     &             + a(aoff+10)*s(2) + a(aoff+14)*s(3)
         x(xoff+3) = a(aoff+3)*s(0)  + a(aoff+7)*s(1)
     &             + a(aoff+11)*s(2) + a(aoff+15)*s(3)
      end do
      return
      end
293: