"""1D linear (P1) finite elements for -(lam(x) u'(x))' = 0 on (0, 1).

The coefficient lam is piecewise constant: LAM_LOWER on (0, SPLIT) and
LAM_UPPER on (SPLIT, 1), with SPLIT = sqrt(2)/2.  Dirichlet data
u(0) = 0, u(1) = 1 are imposed by lifting: the known boundary values are
moved to the right-hand side and only the interior nodes are solved for.

Run as a script to print the system and plot the discrete solution.
"""
import numpy as np

SPLIT = np.sqrt(2) / 2  # location of the coefficient jump
LAM_LOWER = 1.0  # lam on (0, SPLIT)
LAM_UPPER = 10.0  # lam on (SPLIT, 1)


def build_mesh(N, split=SPLIT):
    """Return the N + 1 node coordinates of a mesh on [0, 1] containing ``split``.

    Roughly ``N * split`` elements are placed uniformly in (0, split) and the
    rest uniformly in (split, 1), so the coefficient jump lies on a node.

    Raises:
        ValueError: if N < 2 (both sub-intervals need an element) or
            ``split`` is not strictly inside (0, 1).
    """
    if N < 2:
        raise ValueError("need at least two elements so that split is a mesh node")
    if not 0.0 < split < 1.0:
        raise ValueError("split must lie strictly inside (0, 1)")
    # Clamp so each sub-interval gets at least one element. Without this,
    # N_upper == 0 makes the mesh stop at split instead of 1, and
    # N_lower == 0 drops split from the mesh.
    n_lower = min(max(round(N * split), 1), N - 1)
    n_upper = N - n_lower
    x_lower = np.linspace(0.0, split, n_lower + 1)
    x_upper = np.linspace(split, 1.0, n_upper + 1)
    # x_upper[0] duplicates x_lower[-1] (both equal split); keep only one.
    return np.concatenate([x_lower, x_upper[1:]])


def assemble_stiffness(x, split=SPLIT, lam_lower=LAM_LOWER, lam_upper=LAM_UPPER):
    """Assemble the global P1 stiffness matrix of -(lam u')' on nodes ``x``.

    Each element [x[i-1], x[i]] contributes (lam / h) * [[1, -1], [-1, 1]].
    """
    N = len(x) - 1
    A = np.zeros((N + 1, N + 1))
    local = np.array([[1.0, -1.0], [-1.0, 1.0]])
    for i in range(1, N + 1):
        h = x[i] - x[i - 1]
        # Pick lam from the element midpoint. This stays correct even if
        # split is not exactly a node; using an endpoint does not.
        mid = 0.5 * (x[i - 1] + x[i])
        lam = lam_upper if mid > split else lam_lower
        A[i - 1:i + 1, i - 1:i + 1] += (lam / h) * local
    return A


def solve_dirichlet(A, u_left=0.0, u_right=1.0):
    """Solve A u = 0 on interior nodes with u[0] = u_left and u[-1] = u_right.

    Uses lifting u = u_0 + u_g, where u_g carries the Dirichlet values.
    The interior equations (rows 1..N-1) then read
    A[1:N, 1:N] u_0 = -A[1:N, :] u_g.
    """
    N = A.shape[0] - 1
    u_g = np.zeros(N + 1)
    u_g[0] = u_left
    u_g[N] = u_right
    u = u_g.copy()
    if N > 1:  # there are interior unknowns to solve for
        f = -A[1:N, :] @ u_g
        u[1:N] = np.linalg.solve(A[1:N, 1:N], f)
    return u


def main(N=2, split=SPLIT):
    """Build, print and plot the discrete solution for N elements."""
    import matplotlib.pyplot as plt  # plotting only; keep the solver importable without it

    np.set_printoptions(precision=2)

    x = build_mesh(N, split)
    # split is an exact node (linspace pins its endpoint), so its index is
    # the number of elements below it.
    n_lower = int(np.searchsorted(x, split))
    print(n_lower, "elements below sqrt(2)/2")
    print(N - n_lower, "elements above sqrt(2)/2")
    print(x)

    A = assemble_stiffness(x, split)
    print("A =\n", A)

    u = solve_dirichlet(A, 0.0, 1.0)
    print("u =\n", u)

    plt.plot(x, u, '-')
    plt.xlabel('x')
    plt.ylabel('u_h(x)')
    plt.grid()
    plt.show()


if __name__ == "__main__":
    main()