Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2This module implements the split-Bregman algorithm described in 

3T. Goldstein and S. Osher, SIAM J. Imaging Sci. 2, 323 (2009); 

4doi:10.1137/080725891 

5""" 

6 

7import numpy as np 

8from scipy.optimize import minimize 

9from typing import Any, Dict 

10 

11 

def fit_split_bregman(A: np.ndarray,
                      y: np.ndarray,
                      mu: float = 1e-3,
                      lmbda: float = 100,
                      n_iters: int = 1000,
                      tol: float = 1e-6) -> Dict[str, Any]:
    """
    Determines the solution :math:`\\boldsymbol{x}` to the linear
    problem :math:`\\boldsymbol{A}\\boldsymbol{x}=\\boldsymbol{y}` using
    the split-Bregman algorithm described in T. Goldstein and S. Osher,
    SIAM J. Imaging Sci. 2, 323 (2009); doi:10.1137/080725891.
    The thus obtained parameters are returned in the form of a
    dictionary with a key named `parameters`.

    Parameters
    ----------
    A
        fit matrix
    y
        target array
    mu
        sparseness parameter
    lmbda
        weight of additional L2-norm in split-Bregman
    n_iters
        maximal number of split-Bregman iterations
    tol
        convergence criterion for the iterative minimization; iterations
        stop once the Euclidean norm of the solution vector changes by
        less than ``tol`` between two consecutive iterations

    Returns
    -------
    Dict[str, Any]
        ``{'parameters': x}`` where ``x`` is the solution vector with one
        entry per column of ``A``
    """

    def _shrink(values: np.ndarray, alpha: float) -> np.ndarray:
        """
        Shrinkage operator as defined in Eq. (11) of the paper by Nelson
        et al., Phys. Rev. B 87, 035125 (2013); doi:10.1103/PhysRevB.87.035125.
        """
        return np.sign(values) * np.maximum(np.abs(values) - alpha, 0.0)

    n_cols = A.shape[1]
    d = np.zeros(n_cols)
    b = np.zeros(n_cols)
    x = np.zeros(n_cols)

    old_norm = 0.0

    # Precompute the x-independent products once; they are forwarded to the
    # objective function and its derivative in every iteration.
    AtA = np.dot(A.conj().transpose(), A)
    ftA = np.dot(y.conj().transpose(), A)
    for _ in range(n_iters):
        # d and b are updated every iteration, so the argument tuple shared
        # by the objective and its derivative has to be rebuilt each time.
        args = (A, y, mu, lmbda, d, b, AtA, ftA)
        res = minimize(_objective_function, x, args, method='BFGS',
                       options={'disp': False},
                       jac=_objective_function_derivative)
        x = res.x

        # Bregman updates: shrink to obtain d, then accumulate b.
        d = _shrink(mu*x + b, 1.0/lmbda)
        b = b + mu*x - d

        # NOTE: only the norm of x is compared between iterations, so a
        # change of x that preserves its norm is not detected. Because
        # old_norm starts at 0, the loop also ends after the first
        # iteration if that solution has a norm below tol.
        new_norm = np.linalg.norm(x)
        if abs(new_norm - old_norm) < tol:
            break
        old_norm = new_norm

    return {'parameters': x}

80 

81 

def _objective_function(x: np.ndarray, A: np.ndarray, y: np.ndarray,
                        mu: float, lmbda: float, d: np.ndarray, b: np.ndarray,
                        AtA: np.ndarray, ftA: np.ndarray) -> float:
    """
    Returns the objective function to be minimized.

    The value computed is
    ``0.5 * <A x - y, A x - y> + 0.5 * lmbda * <d - b - mu x, d - b - mu x>``
    where ``<u, v>`` denotes ``np.vdot(u, v)``.

    Parameters
    ----------
    x
        current parameter vector
    A
        fit matrix
    y
        target array
    mu
        the parameter that adjusts sparseness.
    lmbda
        Split Bregman parameter
    d
        same notation as Nelson et al. paper
    b
        same notation as Nelson et al. paper
    AtA
        sensing matrix transpose times sensing matrix. Not used here; it is
        accepted so that this function and its derivative can share the
        same argument tuple in ``scipy.optimize.minimize``.
    ftA
        np.dot(y.conj().transpose(), A). Not used here, for the same reason
        as ``AtA``.

    Returns
    -------
    float
        the (real) value of the objective function

    Raises
    ------
    RuntimeError
        if the objective acquires a non-zero imaginary part
    """

    def _ensure_real(value) -> None:
        # The objective must be real for the minimizer; any imaginary part,
        # positive or negative, indicates an inconsistent input.
        if value.imag != 0.0:
            raise RuntimeError(
                'Objective function contains non-zero imaginary part.')

    error_vector = np.dot(A, x) - y
    obj_function = 0.5*np.vdot(error_vector, error_vector)
    _ensure_real(obj_function)

    sparseness_correction = d - b - mu*x
    obj_function += 0.5*lmbda * \
        np.vdot(sparseness_correction, sparseness_correction)
    _ensure_real(obj_function)

    # The imaginary part was verified to be zero above, so dropping it is
    # lossless and hands the minimizer a plain real scalar.
    return float(obj_function.real)

125 

126 

def _objective_function_derivative(x: np.ndarray,
                                   A: np.ndarray,
                                   y: np.ndarray,
                                   mu: float,
                                   lmbda: float,
                                   d: np.ndarray,
                                   b: np.ndarray,
                                   AtA: np.ndarray,
                                   ftA: np.ndarray) -> np.ndarray:
    """
    Returns the derivative of the objective function.

    For real-valued data this is
    ``A^T (A x - y) - lmbda * mu * (d - mu x - b)``, evaluated via the
    precomputed products ``AtA`` and ``ftA``.

    Parameters
    ----------
    x
        current parameter vector
    A
        fit matrix (not used directly; accepted so that this function and
        the objective share the same argument tuple)
    y
        target array (not used directly, for the same reason)
    mu
        the parameter that adjusts sparseness.
    lmbda
        Split Bregman parameter
    d
        same notation as Nelson, Hart paper
    b
        same notation as Nelson, Hart paper
    AtA
        sensing matrix transpose times sensing matrix.
    ftA
        np.dot(y.conj().transpose(), A)

    Returns
    -------
    np.ndarray
        one-dimensional gradient with the same number of entries as ``x``
    """
    # x @ AtA - ftA: gradient of the least-squares term;
    # the last term: gradient of the lmbda-weighted split-Bregman term.
    gradient = np.dot(x, AtA) - ftA - lmbda*mu*(d - mu*x - b)
    # ravel rather than squeeze: it flattens a (1, n) result (e.g. from a
    # 2-D ftA) to (n,) but keeps shape (1,) for a single-parameter problem
    # instead of collapsing it to a 0-d array.
    return np.ravel(gradient)