ماتریس‌ها و مشتق: حل مسأله

توسط بهروز ودادیان در ۹ بهمن ۱۳۹۷math

عمده‌ی کارکرد مشتق‌گیری، یافتن نقاط اکسترموم است؛ اما نحوه‌ی استفاده از مشتق برای رسیدن به اکسترموم می‌تواند یا از طریق روش‌های iterative بهینه‌سازی باشد یا حل فرمال معادلات. برای اینکه از هر کدام از این موارد یک مثال دیده باشیم، اول مقدار ماتریس correlation را از روی تعدادی نمونه از یک توزیع گوسی تخمین می‌زنیم و بعدش گرادیان یک MLP را نسبت به وزن‌هایش محاسبه می‌کنیم.

تخمین ماتریس correlation

فرمول توزیع گوسی چند متغیره چیزی است که در زیر آمده:

p(x|μ,Σ)=det(۲πΣ)۱۲exp(۱۲(xμ)TΣ۱(xμ))

خب حالا اگر نمونه‌های x۱ تا xn از این توزیع داشته باشیم، چه مقادیری از μ و Σ احتمال رخداد این نمونه‌ها را بیشینه می‌کنند؟

با فرض اینکه نمونه‌ها از هم مستقل باشند، احتمال رخداد همه‌شان می‌شود:

p(x۱,x۲,xn|μ,Σ)=p(xi|μ,Σ)log(p(x۱,x۲,xn|μ,Σ))=log(p(xi|μ,Σ))

وقتی می‌شود با لگاریتم احتمال همان کاری را کرد که با احتمال می‌شود کرد، چرا سراغش نرویم. مخصوصاً برای توزیع گوسی که تابع نمایی نقش اساسی در آن دارد. خب، با جایگذاری رابطه‌ی توزیع احتمال در فرمول لگاریتم‌ها داریم:

log(p(x۱,x۲,xn|μ,Σ))=[log(det(۲πΣ)۱۲)۱۲(xiμ)TΣ۱(xiμ)]

باز هم ساده‌ترش کنیم می‌شود:

y=log(p(x۱,x۲,xn|μ,Σ))=n۲log(۲mπm)n۲log(det(Σ))۱۲(xiμ)TΣ۱(xiμ)

برای این ساده‌سازی، من اول عبارت دترمینان را که ربطی به اندیس جمع نداشت از جمع خارج کردم که طبیعتاً ضرب در تعداد دفعات جمع خوردن می‌شود (یعنی n). بعدش ۱۲ را از لگاریتم خارج کردم که بصورت ضرب ظاهر شود. در مرحله‌ی بعد، det(۲πΣ) را با استفاده از رابطه‌ی det(kA)=kndet(A) ساده کردم. در نهایت هم ضریب ۱۲ را از درون جمع به بیرون انتقال دادم که به دلیل خاصیت پخشی ضرب امکان‌پذیر است (توجه کرده‌اید که n تعداد نمونه‌ها و m تعداد عناصر هر نمونه است؟)

حالا کافیست دو تا مشتق بگیریم و برابر صفر قرار دهیم. یکی نسبت به μ(یک بردار) و دیگری نسبت به Σ (یک ماتریس). اول با بردار شروع می‌کنیم:

yμ=۱۲(Σ۱+ΣT)((xi)nμ)=۰μ=۱nxi

به همین سرعت محاسبه شد! می‌دانید چرا به Σ۱ و همزادش ΣT محل ندادم؟ چون طبق فرمول، این ماتریس‌ها معکوس‌پذیرند و فقط در یک حالت حاصل ضرب آن‌ها در یک بردار برابر صفر می‌شود و آن وقتی است که خود آن بردار صفر باشد؛ یعنی (xi)nμ=۰. اما مشتق دوم که جذاب‌تر هم هست:

yΣ=yΣ۱Σ۱ΣyΣ۱=n۲vecT(ΣT)۱۲((xiμ)T(xiμ)T)Σ۱Σ=(ΣTΣ۱)yΣ=n۲vecT(ΣT)(ΣTΣ۱)+۱۲((xiμ)T(xiμ)T)(ΣTΣ۱)

از طول و درازای فرمول بالا نترسید، در دلش چیزی نیست. کافیست از رابطه‌ی زیر برای ساده‌سازی استفاده کنیم:

vec(AXB)=(BTA)vec(X)

و یا معادلش:

vecT(AXB)=vecT(X)(BAT)

با جایگذاری به نتیجه‌ی زیر می‌رسیم:

yΣ=n۲vecT(ΣT)+۱۲[((xiμ)T(xiμ)T)](ΣTΣ۱)=۰

یا به عبارتی:

vecT(ΣT)(ΣTΣ۱)۱=۱n((xiμ)T(xiμ)T)=۰

باز هم همان رابطه‌ی vec(AXB). در پست اول که گفتم، این رابطه در مشتق‌گیری خیلی کاربرد دارد:

vecT(ΣT)(ΣTΣ۱)۱=vecT(ΣT)(ΣTΣ)=vecT(ΣTΣTΣT)=vecT(ΣT)

اورکا، اورکا:

vecT(ΣT)=۱n((xiμ)T(xiμ)T)

باز هم vec(AXB)! اگر X یک عدد اسکالر، A یک بردار سطری (n×۱) و B یک بردار ستونی باشد، چه رابطه‌ای حاصل می‌شود؟

vec(αab)=vec(aαb)=(bTa)vec(α)=α(bTa)vecT(αab)=vecT(aαb)=vecT(α)(baT)=α(baT)

زیبا نیست؟ با این حساب:

(xiμ)T(xiμ)T=vecT((xiμ)(xiμ)T)vecT(ΣT)=۱nvecT((xiμ)(xiμ)T)vecT(ΣT)=vecT(۱n((xiμ)(xiμ)T))Σ=ΣT=۱n((xiμ)(xiμ)T)

در اینجا باید بپرسید که چرا Σ=ΣT؟ جوابش این است که وقتی ΣT=۱n((xiμ)(xiμ)T) باشد، ترانهاده‌اش برابر خودش می‌شود.

مشتق نسبت به وزن‌های MLP

لایه‌های یک شبکه‌ی MLP را می‌توان بصورت زیر نشان داد:

yn+۱=ϕ(Wnyn+bn)

بسته به این که MLP مورد نظر ما چند لایه باشد، وقتی می‌خواهیم آموزشش بدهیم، خروجی لایه‌ی mم را با نتایج دلخواه مقایسه می‌کنیم و خطا را محاسبه می‌کنیم. فرض کنیم که معیار خطای کمترین متوسط مربعات را در نظر داشته باشیم:

j(W۰,b۰,W۱,b۱,Wm۱,bm۱)=۱۲ymy۲=۱۲(ymy)T(ymy)

از ماشین‌آلاتی که برای مشتق‌گیری ساختیم برای محاسبه‌ی jWi استفاده می‌کنیم. محاسبه‌ی مشتق نسبت به بایاس‌ها هم کاملاً مشابه است:

jWi=jym(m۱iyj+۱yj)yiWi

حالا کافیست هر جزء را محاسبه کنیم.

jym=(ymy)Tyj+۱yj=diag(vec(ϕ(Wjyj+bj)))WjyiWi=(yiTI)

یافتن اینکه این‌ها را چطور با استفاده از مطالبی که در سه پُست گذشته ارائه کردم محاسبه کرده‌ام به خود شما واگذار می‌کنم که بیابید (این خودش تمرین خوبی است.) اما مباحث از اینجا به بعد شیرین‌تر می‌شود. برای اینکه بتوانم ادامه بدهم باید یک رابطه‌ی جدید را معرفی کنم.

ab=diag(a)b(ab)T=bTdiag(a)

علامت نمایشگر ضرب هادامارد یا عنصربه‌عنصر است. از طرفی وقتی Wjyj+bj یک بردار است، ϕ هم رویش اعمال شود، باز هم خروجی یک بردار است، پس می‌شود کل مشتق بالا را اینگونه نوشت:

jWi=(ymy)T(diag(ϕ(Wjyj+bj))Wj)(yiTI)

بد نیست ببینیم، آن در وسط در واقع چه بلایی سر مقدار خطا (ymy) می‌آورد. از خود خطا شروع می‌کنیم:

ejT=(ymy)Tdiag(ϕ(Wjyj+bj))Wj=((ymy)ϕ(Wjyj+bj))TWj

انگار هر عنصر از خطا، با مقدار مشتق تابع ϕ به ازای همان عنصر وزن‌دهی می‌شود و بعدش بردار به دست آمده از سمت چپ در ماتریس وزن لایه‌ی قبلی ضرب می‌شود. این همان پس انتشار خطاست. هر عنصری از خطا، با وزن مربوطه، به یکی از گره‌های لایه‌ی زیرین مرتبط می‌شود. همین عملیات را می‌شود ادامه داد:

ej۱T=ejTdiag(ϕ(Wj۱yj۱+bj۱))Wj۱=(ejϕ(Wjyj+bj))TWj۱

باز هم مقدار خطای لایه‌ی بالا با مشتق ϕ در لایه‌ی زیرین وزن‌دهی شده و به لایه‌ی پایین‌تر منتشر می‌شود. اگر با همین فرمان پیش برویم به خطای لایه‌ی مورد نظر یعنی ei می‌رسیم:

jWi=eiT(yiI)=vecT(eiyiT)

تمرین خوبی است که ببینید چرا رابطه‌ی بالا برقرار است. بررسی اینکه vec(ei) چه فرمی دارد در این مورد نقطه‌ی شروع است. بگذریم.

در نهایت این مقدار مشتق چه کاربردی دارد؟ از آن در بهینه‌سازی و رفتن به سمت مقدار آموزش دیده‌ی وزن‌ها استفاده می‌کنیم. در روش‌های بهینه‌سازی مبتنی بر گرادیان، همیشه مقدار گرادیان خطا نسبت به پارامترها با نسبتی منفی جمع می‌شود. این اساس همه‌ی روش‌های بهینه‌سازی مبتنی بر گرادیان است. اما در اینجا گرادیان نسبت به Wi ابعادی مشابه با Wi ندارد؛ چگونه باید این دو را جمع بزنیم؟

اگر از ابتدا سراغ تنسورها رفته بودیم، این مشکل پیش نمی‌آمد. مقدار مشتق هم ابعادی برابر با خود Wi می‌داشت (صد البته چون j یک عدد اسکالر است.) حالا هم اتفاق خاصی نیفتاده. یادتان هست که برای تعریف اینگونه مشتق‌گیری که در این چهار پُست با هم مرورش کردیم، از ابتدا قرار گذاشتیم که هر ماتریس را با یک زوج مرتب (vec(A),(m,n)) نمایش بدهیم. الآن هم که مشتق را داریم، ابعاد ماتریس‌های Wi را هم که داریم. کافیست مشتق محاسبه شده را با ابعاد ماتریس اولیه باز نمایی کنیم. تعداد عناصر مشتق طبق تعریف یکی است.

نکته‌ی جالب ماجرا اینجاست که تا الآن که این پست را می‌نویسم، بدون استثنا، باز مرتب‌سازی مشتق برای من برابر با نگاه کردن به درون عملگر vec بوده است. همین مثال MLP را نگاه کنید:

jWi=eiT(yiI)=vecT(eiyiT)

مقدار eiyiT ابعادی دقیقاً برابر ماتریس اولیه دارد!

نتایج اخلاقی

در نهایت این نتایج را به عنوان خلاصه تقدیم می‌کنم:

۱-برای گرفتن مشتق نسبت به ماتریس نیازی به محاسبات تنسوری نیست ۲-مشتق‌گرفتن برای ورزش ذهنی خوب است ۳-اصولاً اکثر اوقات نیاز نیست خودتان مشتق نسبت به ماتریس بگیرید!