Jump to content

Gradient descent: Difference between revisions

m (start as a draft, fix language name)
Line 111:
The minimum is at x[0] = 0.10768224291553158, x[1] = -1.2233090211217854
</pre>
 
===Optimized solution===
<lang Typescript>
// Sample data set for the straight-line fit below: one [x, y] pair per row.
// gradientDescentMain reads column 0 as the inputs X and column 1 as the
// observed values Y.
let data: number[][] =
[[32.5023452694530, 31.70700584656990],
[53.4268040332750, 68.77759598163890],
[61.5303580256364, 62.56238229794580],
[47.4756396347860, 71.54663223356770],
[59.8132078695123, 87.23092513368730],
[55.1421884139438, 78.21151827079920],
[52.2117966922140, 79.64197304980870],
[39.2995666943170, 59.17148932186950],
[48.1050416917682, 75.33124229706300],
[52.5500144427338, 71.30087988685030],
[45.4197301449737, 55.16567714595910],
[54.3516348812289, 82.47884675749790],
[44.1640494967733, 62.00892324572580],
[58.1684707168577, 75.39287042599490],
[56.7272080570966, 81.43619215887860],
[48.9558885660937, 60.72360244067390],
[44.6871962314809, 82.89250373145370],
[60.2973268513334, 97.37989686216600],
[45.6186437729558, 48.84715331735500],
[38.8168175374456, 56.87721318626850],
[66.1898166067526, 83.87856466460270],
[65.4160517451340, 118.59121730252200],
[47.4812086078678, 57.25181946226890],
[41.5756426174870, 51.39174407983230],
[51.8451869056394, 75.38065166531230],
[59.3708220110895, 74.76556403215130],
[57.3100034383480, 95.45505292257470],
[63.6155612514533, 95.22936601755530],
[46.7376194079769, 79.05240616956550],
[50.5567601485477, 83.43207142132370],
[52.2239960855530, 63.35879031749780],
[35.5678300477466, 41.41288530370050],
[42.4364769440556, 76.61734128007400],
[58.1645401101928, 96.76956642610810],
[57.5044476153417, 74.08413011660250],
[45.4405307253199, 66.58814441422850],
[61.8962226802912, 77.76848241779300],
[33.0938317361639, 50.71958891231200],
[36.4360095113868, 62.12457081807170],
[37.6756548608507, 60.81024664990220],
[44.5556083832753, 52.68298336638770],
[43.3182826318657, 58.56982471769280],
[50.0731456322890, 82.90598148507050],
[43.8706126452183, 61.42470980433910],
[62.9974807475530, 115.24415280079500],
[32.6690437634671, 45.57058882337600],
[40.1668990087037, 54.08405479622360],
[53.5750775316736, 87.99445275811040],
[33.8642149717782, 52.72549437590040],
[64.7071386661212, 93.57611869265820],
[38.1198240268228, 80.16627544737090],
[44.5025380646451, 65.10171157056030],
[40.5995383845523, 65.56230126040030],
[41.7206763563412, 65.28088692082280],
[51.0886346783367, 73.43464154632430],
[55.0780959049232, 71.13972785861890],
[41.3777265348952, 79.10282968354980],
[62.4946974272697, 86.52053844034710],
[49.2038875408260, 84.74269780782620],
[41.1026851873496, 59.35885024862490],
[41.1820161051698, 61.68403752483360],
[50.1863894948806, 69.84760415824910],
[52.3784462192362, 86.09829120577410],
[50.1354854862861, 59.10883926769960],
[33.6447060061917, 69.89968164362760],
[39.5579012229068, 44.86249071116430],
[56.1303888168754, 85.49806777884020],
[57.3620521332382, 95.53668684646720],
[60.2692143939979, 70.25193441977150],
[35.6780938894107, 52.72173496477490],
[31.5881169981328, 50.39267013507980],
[53.6609322616730, 63.64239877565770],
[46.6822286494719, 72.24725106866230],
[43.1078202191024, 57.81251297618140],
[70.3460756150493, 104.25710158543800],
[44.4928558808540, 86.64202031882200],
[57.5045333032684, 91.48677800011010],
[36.9300766091918, 55.23166088621280],
[55.8057333579427, 79.55043667850760],
[38.9547690733770, 44.84712424246760],
[56.9012147022470, 80.20752313968270],
[56.8689006613840, 83.14274979204340],
[34.3331247042160, 55.72348926054390],
[59.0497412146668, 77.63418251167780],
[57.7882239932306, 99.05141484174820],
[54.2823287059674, 79.12064627468000],
[51.0887198989791, 69.58889785111840],
[50.2828363482307, 69.51050331149430],
[44.2117417520901, 73.68756431831720],
[38.0054880080606, 61.36690453724010],
[32.9404799426182, 67.17065576899510],
[53.6916395710700, 85.66820314500150],
[68.7657342696216, 114.85387123391300],
[46.2309664983102, 90.12357206996740],
[68.3193608182553, 97.91982103524280],
[50.0301743403121, 81.53699078301500],
[49.2397653427537, 72.11183246961560],
[50.0395759398759, 85.23200734232560],
[48.1498588910288, 66.22495788805460],
[25.1284846477723, 53.45439421485050]];
 
/**
 * Computes the gradient of the mean squared error of a straight-line model
 * `y = m * x + c` with respect to its two parameters.
 *
 * Despite the name, this returns the partial derivatives of the loss, not the
 * loss value itself:
 *
 *   D_m = (-2/n) * sum(X * (Y - Y_pred))
 *   D_c = (-2/n) * sum(Y - Y_pred)
 *
 * @param arr0 - Input values X.
 * @param arr1 - Observed values Y, one per element of X.
 * @param arr2 - Values currently predicted by the model (Y_pred), one per element of X.
 * @returns `[D_m, D_c]`: the derivative with respect to the slope m, then the
 *          derivative with respect to the intercept c.
 * @throws RangeError if the arrays are empty or do not all have the same length.
 */
function lossFunction(arr0: number[], arr1: number[], arr2: number[]): [number, number] {
  const n = arr0.length; // Number of elements in X

  // An empty sample would divide by zero; fail with a clear message instead.
  if (n === 0) {
    throw new RangeError("lossFunction requires at least one data point");
  }
  // Mismatched lengths would otherwise silently produce NaN or ignore data.
  if (arr1.length !== n || arr2.length !== n) {
    throw new RangeError("lossFunction requires X, Y and Y_pred to have the same length");
  }

  // Accumulate both sums in a single pass, left to right, without
  // allocating intermediate arrays (this runs once per epoch).
  let weightedResidualSum = 0; // sum(X * (Y - Y_pred))
  let residualSum = 0; // sum(Y - Y_pred)
  for (let i = 0; i < n; i++) {
    const residual = arr1[i] - arr2[i];
    weightedResidualSum += arr0[i] * residual;
    residualSum += residual;
  }

  const scale = -2 / n;
  return [scale * weightedResidualSum, scale * residualSum];
}
 
/**
 * Fits the line `y = m * x + c` to the module-level `data` set by gradient
 * descent and prints the resulting `m` and `c` to the console.
 *
 * Each epoch predicts Y from the current m and c, asks `lossFunction` for the
 * two partial derivatives, and moves m and c against them by the learning rate.
 *
 * @param learningRate - Step size applied to each derivative (defaults to 1e-8).
 * @param epochs - Number of gradient descent iterations to perform (defaults to 10,000,000).
 */
export const gradientDescentMain = (
  learningRate: number = 0.00000001,
  epochs: number = 10000000,
): void => {
  // The data does not change between epochs, so split it into its X and Y
  // columns once instead of on every iteration.
  const xValues: number[] = data.map((row) => row[0]);
  const yValues: number[] = data.map((row) => row[1]);

  // Initial guesses for the slope and the intercept
  let m = 0;
  let c = 0;

  for (let i = 0; i < epochs; i++) {
    // The current predicted value of Y
    const predictions = xValues.map((x) => m * x + c);

    // gradient[0] is the derivative wrt m, gradient[1] the derivative wrt c
    const gradient = lossFunction(xValues, yValues, predictions);

    m = m - learningRate * gradient[0]; // Update m
    c = c - learningRate * gradient[1]; // Update c
  }

  console.log("m: " + m + " c: " + c);
};
 
// Run the gradient descent when this file is executed; it logs the fitted m and c.
gradientDescentMain();
</lang>
Cookies help us deliver our services. By using our services, you agree to our use of cookies.