Gradient descent: Difference between revisions
Content added Content deleted
Thundergnat (talk | contribs) m (start as a draft, fix language name) |
|||
Line 111: | Line 111: | ||
The minimum is at x[0] = 0.10768224291553158, x[1] = -1.2233090211217854 |
The minimum is at x[0] = 0.10768224291553158, x[1] = -1.2233090211217854 |
||
</pre> |
</pre> |
||
===More optimal solution=== |
|||
<lang Typescript> |
|||
let data: number[][] = |
|||
[[32.5023452694530, 31.70700584656990], |
|||
[53.4268040332750, 68.77759598163890], |
|||
[61.5303580256364, 62.56238229794580], |
|||
[47.4756396347860, 71.54663223356770], |
|||
[59.8132078695123, 87.23092513368730], |
|||
[55.1421884139438, 78.21151827079920], |
|||
[52.2117966922140, 79.64197304980870], |
|||
[39.2995666943170, 59.17148932186950], |
|||
[48.1050416917682, 75.33124229706300], |
|||
[52.5500144427338, 71.30087988685030], |
|||
[45.4197301449737, 55.16567714595910], |
|||
[54.3516348812289, 82.47884675749790], |
|||
[44.1640494967733, 62.00892324572580], |
|||
[58.1684707168577, 75.39287042599490], |
|||
[56.7272080570966, 81.43619215887860], |
|||
[48.9558885660937, 60.72360244067390], |
|||
[44.6871962314809, 82.89250373145370], |
|||
[60.2973268513334, 97.37989686216600], |
|||
[45.6186437729558, 48.84715331735500], |
|||
[38.8168175374456, 56.87721318626850], |
|||
[66.1898166067526, 83.87856466460270], |
|||
[65.4160517451340, 118.59121730252200], |
|||
[47.4812086078678, 57.25181946226890], |
|||
[41.5756426174870, 51.39174407983230], |
|||
[51.8451869056394, 75.38065166531230], |
|||
[59.3708220110895, 74.76556403215130], |
|||
[57.3100034383480, 95.45505292257470], |
|||
[63.6155612514533, 95.22936601755530], |
|||
[46.7376194079769, 79.05240616956550], |
|||
[50.5567601485477, 83.43207142132370], |
|||
[52.2239960855530, 63.35879031749780], |
|||
[35.5678300477466, 41.41288530370050], |
|||
[42.4364769440556, 76.61734128007400], |
|||
[58.1645401101928, 96.76956642610810], |
|||
[57.5044476153417, 74.08413011660250], |
|||
[45.4405307253199, 66.58814441422850], |
|||
[61.8962226802912, 77.76848241779300], |
|||
[33.0938317361639, 50.71958891231200], |
|||
[36.4360095113868, 62.12457081807170], |
|||
[37.6756548608507, 60.81024664990220], |
|||
[44.5556083832753, 52.68298336638770], |
|||
[43.3182826318657, 58.56982471769280], |
|||
[50.0731456322890, 82.90598148507050], |
|||
[43.8706126452183, 61.42470980433910], |
|||
[62.9974807475530, 115.24415280079500], |
|||
[32.6690437634671, 45.57058882337600], |
|||
[40.1668990087037, 54.08405479622360], |
|||
[53.5750775316736, 87.99445275811040], |
|||
[33.8642149717782, 52.72549437590040], |
|||
[64.7071386661212, 93.57611869265820], |
|||
[38.1198240268228, 80.16627544737090], |
|||
[44.5025380646451, 65.10171157056030], |
|||
[40.5995383845523, 65.56230126040030], |
|||
[41.7206763563412, 65.28088692082280], |
|||
[51.0886346783367, 73.43464154632430], |
|||
[55.0780959049232, 71.13972785861890], |
|||
[41.3777265348952, 79.10282968354980], |
|||
[62.4946974272697, 86.52053844034710], |
|||
[49.2038875408260, 84.74269780782620], |
|||
[41.1026851873496, 59.35885024862490], |
|||
[41.1820161051698, 61.68403752483360], |
|||
[50.1863894948806, 69.84760415824910], |
|||
[52.3784462192362, 86.09829120577410], |
|||
[50.1354854862861, 59.10883926769960], |
|||
[33.6447060061917, 69.89968164362760], |
|||
[39.5579012229068, 44.86249071116430], |
|||
[56.1303888168754, 85.49806777884020], |
|||
[57.3620521332382, 95.53668684646720], |
|||
[60.2692143939979, 70.25193441977150], |
|||
[35.6780938894107, 52.72173496477490], |
|||
[31.5881169981328, 50.39267013507980], |
|||
[53.6609322616730, 63.64239877565770], |
|||
[46.6822286494719, 72.24725106866230], |
|||
[43.1078202191024, 57.81251297618140], |
|||
[70.3460756150493, 104.25710158543800], |
|||
[44.4928558808540, 86.64202031882200], |
|||
[57.5045333032684, 91.48677800011010], |
|||
[36.9300766091918, 55.23166088621280], |
|||
[55.8057333579427, 79.55043667850760], |
|||
[38.9547690733770, 44.84712424246760], |
|||
[56.9012147022470, 80.20752313968270], |
|||
[56.8689006613840, 83.14274979204340], |
|||
[34.3331247042160, 55.72348926054390], |
|||
[59.0497412146668, 77.63418251167780], |
|||
[57.7882239932306, 99.05141484174820], |
|||
[54.2823287059674, 79.12064627468000], |
|||
[51.0887198989791, 69.58889785111840], |
|||
[50.2828363482307, 69.51050331149430], |
|||
[44.2117417520901, 73.68756431831720], |
|||
[38.0054880080606, 61.36690453724010], |
|||
[32.9404799426182, 67.17065576899510], |
|||
[53.6916395710700, 85.66820314500150], |
|||
[68.7657342696216, 114.85387123391300], |
|||
[46.2309664983102, 90.12357206996740], |
|||
[68.3193608182553, 97.91982103524280], |
|||
[50.0301743403121, 81.53699078301500], |
|||
[49.2397653427537, 72.11183246961560], |
|||
[50.0395759398759, 85.23200734232560], |
|||
[48.1498588910288, 66.22495788805460], |
|||
[25.1284846477723, 53.45439421485050]]; |
|||
function lossFunction(arr0: number[], arr1: number[], arr2: number[]) { |
|||
let n: number = arr0.length; // Number of elements in X |
|||
//D_m = (-2/n) * sum(X * (Y - Y_pred)) # Derivative wrt m |
|||
let a: number = (-2 / n) * (arr0.map((a, i) => a * (arr1[i] - arr2[i]))).reduce((sum, current) => sum + current); |
|||
//D_c = (-2/n) * sum(Y - Y_pred) # Derivative wrt c |
|||
let b: number = (-2 / n) * (arr1.map((a, i) => (a - arr2[i]))).reduce((sum, current) => sum + current); |
|||
return [a, b]; |
|||
} |
|||
export const gradientDescentMain = () => { |
|||
// Building the model |
|||
let m: number = 0; |
|||
let c: number = 0; |
|||
let X_arr: number[]; |
|||
let Y_arr: number[]; |
|||
let Y_pred_arr: number[]; |
|||
let D_m: number = 0; |
|||
let D_c: number = 0; |
|||
let L: number = 0.00000001; // The learning Rate |
|||
let epochs: number = 10000000; // The number of iterations to perform gradient descent |
|||
//Initial guesses |
|||
for (let i = 0; i < epochs; i++) { |
|||
X_arr = data.map(function (value, index) { return value[0]; }); |
|||
Y_arr = data.map(function (value, index) { return value[1]; }); |
|||
// The current predicted value of Y |
|||
Y_pred_arr = X_arr.map((a) => ((m * a) + c)); |
|||
let all = lossFunction(X_arr, Y_arr, Y_pred_arr); |
|||
D_m = all[0]; |
|||
D_c = all[1]; |
|||
m = m - L * D_m; // Update m |
|||
c = c - L * D_c; // Update c |
|||
} |
|||
console.log("m: " + m + " c: " + c); |
|||
} |
|||
gradientDescentMain(); |
|||
</lang> |