import Tensor from "../../../util/tensor.js";
import Layer, { NeuralnetworkLayerException } from "./base.js";

/**
 * Batch normalization layer.
 *
 * Normalizes the input with a per-channel mean and variance (either computed
 * from the current input or supplied through `input_mean` / `input_var`), then
 * applies a per-channel scale and offset.
 *
 * The class is declared under the name "BatchNormalizationLayer"; the minified
 * original assigned that same `name` explicitly through a defineProperty helper.
 * NOTE(review): `registLayer()` presumably derives the registered layer type
 * from the class name - confirm against ./base.js before renaming the class.
 */
export default class BatchNormalizationLayer extends Layer {
	/**
	 * @param {object} config Layer configuration; unrecognized keys are forwarded to the base layer constructor
	 * @param {number | number[] | string} [config.scale] Scale value(s), or the name of the graph node providing them
	 * @param {number | number[] | string} [config.offset] Offset value(s), or the name of the graph node providing them
	 * @param {number} [config.epsilon] Value added to the variance before taking the square root
	 * @param {number} [config.channel_dim] Channel axis; only -1 and 1 are accepted
	 * @param {number[] | string} [config.input_mean] Fixed mean values, or the name of the graph node providing them
	 * @param {number[] | string} [config.input_var] Fixed variance values, or the name of the graph node providing them
	 * @throws {NeuralnetworkLayerException} If `channel_dim` is neither -1 nor 1
	 */
	constructor({ scale = 1, offset = 0, epsilon = 1e-12, channel_dim = -1, input_mean, input_var, ...rest }) {
		super(rest);

		// A string refers to another node by name; anything else is the value itself.
		this._scale = null;
		if (typeof scale === "string") {
			this._scalename = scale;
		} else {
			this._scale = scale;
		}

		this._offset = null;
		if (typeof offset === "string") {
			this._offsetname = offset;
		} else {
			this._offset = offset;
		}

		this._epsilon = epsilon;
		this._channel_dim = channel_dim;
		if (this._channel_dim !== -1 && this._channel_dim !== 1) {
			throw new NeuralnetworkLayerException("Invalid channel dimension.");
		}

		this._input_mean = input_mean;
		this._input_var = input_var;
	}

	/**
	 * Names of the graph nodes this layer reads scale, offset, mean and variance from.
	 *
	 * @type {string[]}
	 */
	get dependentLayers() {
		const names = [];
		if (this._scalename) {
			names.push(this._scalename);
		}
		if (this._offsetname) {
			names.push(this._offsetname);
		}
		if (typeof this._input_mean === "string") {
			names.push(this._input_mean);
		}
		if (typeof this._input_var === "string") {
			names.push(this._input_var);
		}
		return names;
	}

	/**
	 * Mean used by the most recent `calc` call.
	 */
	get mean() {
		return this._mean;
	}

	/**
	 * Variance used by the most recent `calc` call.
	 */
	get var() {
		return this._var;
	}

	/**
	 * Forward pass: `(x - mean) / sqrt(var + epsilon) * scale + offset`.
	 *
	 * Caches the centered input (`_xc`) and the normalized input (`_xh`) for
	 * `grad` and `update`.
	 *
	 * @param {Tensor} x Input tensor
	 * @returns {Tensor} Normalized, scaled and shifted copy of the input
	 */
	calc(x) {
		const channelAxis = this._channel_dim < 0 ? this._channel_dim + x.dimension : this._channel_dim;

		// Shape of the per-channel parameters: the channel size followed by 1s
		// for every axis after the channel axis.
		const paramShape = Array(x.dimension - channelAxis).fill(1);
		paramShape[0] = x.sizes[channelAxis];

		// Resolve scale: named node output, scalar, or plain array.
		// A value that is already a tensor (e.g. from a previous call) is kept as is.
		if (this._scalename) {
			this._scale = this.graph.getNode(this._scalename).outputValue;
			this._scale.reshape(...paramShape);
		} else if (typeof this._scale === "number") {
			this._scale = new Tensor(paramShape, this._scale);
		} else if (Array.isArray(this._scale)) {
			this._scale = Tensor.fromArray(this._scale);
			this._scale.reshape(...paramShape);
		}

		// Resolve offset the same way.
		if (this._offsetname) {
			this._offset = this.graph.getNode(this._offsetname).outputValue;
			this._offset.reshape(...paramShape);
		} else if (typeof this._offset === "number") {
			this._offset = new Tensor(paramShape, this._offset);
		} else if (Array.isArray(this._offset)) {
			this._offset = Tensor.fromArray(this._offset);
			this._offset.reshape(...paramShape);
		}

		// Externally supplied statistics: named node output or plain array.
		if (typeof this._input_mean === "string") {
			this._mean = this.graph.getNode(this._input_mean).outputValue;
			this._mean.reshape(...paramShape);
		} else if (Array.isArray(this._input_mean)) {
			this._mean = Tensor.fromArray(this._input_mean);
			this._mean.reshape(...paramShape);
		}

		if (typeof this._input_var === "string") {
			this._var = this.graph.getNode(this._input_var).outputValue;
			this._var.reshape(...paramShape);
		} else if (Array.isArray(this._input_var)) {
			this._var = Tensor.fromArray(this._input_var);
			this._var.reshape(...paramShape);
		}

		// Compute whichever statistic was not supplied from the input itself,
		// reducing over every axis except the channel axis.
		if (!this._input_mean || !this._input_var) {
			const reduceAxes = Array.from({ length: x.dimension }, (unused, axis) => axis);
			reduceAxes.splice(channelAxis, 1);
			const count = reduceAxes.reduce((acc, axis) => acc * x.sizes[axis], 1);

			const batchMean = x.reduce((acc, value) => acc + value / count, 0, reduceAxes, true);
			if (!this._input_mean) {
				this._mean = batchMean;
			}
			if (!this._input_var) {
				// The variance is always taken around the batch mean, even when
				// a mean was supplied externally.
				const squaredDiff = x.copy();
				squaredDiff.broadcastOperate(batchMean, (value, mean) => (value - mean) ** 2);
				this._var = squaredDiff.reduce((acc, value) => acc + value / count, 0, reduceAxes, true);
			}
		}

		this._xc = x.copy();
		this._xc.broadcastOperate(this._mean, (value, mean) => value - mean);

		this._xh = this._xc.copy();
		this._xh.broadcastOperate(this._var, (value, variance) => value / Math.sqrt(variance + this._epsilon));

		const out = this._xh.copy();
		out.broadcastOperate(this._scale, (value, scale) => value * scale);
		out.broadcastOperate(this._offset, (value, offset) => value + offset);
		return out;
	}

	/**
	 * Backward pass.
	 *
	 * Stores the incoming gradient in `_bo` for a later `update` call.
	 *
	 * @param {Tensor} bo Gradient with respect to this layer's output
	 * @returns {Tensor | [Tensor, Object<string, Tensor>]} Gradient with respect to the input; when scale and/or
	 * offset come from named nodes, a pair of that gradient and an object mapping those node names to their gradients
	 */
	grad(bo) {
		const channelAxis = this._channel_dim < 0 ? this._channel_dim + bo.dimension : this._channel_dim;
		this._bo = bo;

		// Gradient with respect to the normalized input.
		const scaledGrad = this._bo.copy();
		scaledGrad.broadcastOperate(this._scale, (grad, scale) => grad * scale);

		const centeredTimesGrad = this._xc.copy();
		centeredTimesGrad.broadcastOperate(scaledGrad, (centered, grad) => centered * grad);

		const reduceAxes = Array.from({ length: bo.dimension }, (unused, axis) => axis);
		reduceAxes.splice(channelAxis, 1);
		const count = reduceAxes.reduce((acc, axis) => acc * bo.sizes[axis], 1);

		const meanCenteredTimesGrad = centeredTimesGrad.reduce(
			(acc, value) => acc + value / count,
			0,
			reduceAxes,
			true
		);

		// bi = (scaledGrad - xc * mean(xc * scaledGrad) / (var + eps)) / sqrt(var + eps),
		// then its per-channel mean is subtracted.
		const bi = this._xc.copy();
		bi.broadcastOperate(this._var, (value, variance) => value / (variance + this._epsilon));
		bi.broadcastOperate(meanCenteredTimesGrad, (value, mean) => value * mean);
		bi.broadcastOperate(scaledGrad, (value, grad) => grad - value);
		bi.broadcastOperate(this._var, (value, variance) => value / Math.sqrt(variance + this._epsilon));

		const meanBi = bi.reduce((acc, value) => acc + value / count, 0, reduceAxes, true);
		bi.broadcastOperate(meanBi, (value, mean) => value - mean);

		if (this._scalename || this._offsetname) {
			const namedGrads = {};
			if (this._scalename) {
				const scaleGrad = this._bo.reduce(
					(acc, value, index) => acc + value * this._xh.at(index) / count,
					0,
					reduceAxes,
					true
				);
				namedGrads[this._scalename] = scaleGrad;
			}
			if (this._offsetname) {
				const offsetGrad = this._bo.reduce((acc, value) => acc + value / count, 0, reduceAxes, true);
				namedGrads[this._offsetname] = offsetGrad;
			}
			return [bi, namedGrads];
		}
		return bi;
	}

	/**
	 * Applies the optimizer step to the scale and offset owned by this layer.
	 * Parameters that come from named nodes are skipped.
	 *
	 * @param {object} optimizer Object whose `delta(key, gradient)` result is subtracted from the parameter
	 */
	update(optimizer) {
		if (this._scalename && this._offsetname) {
			return;
		}

		const channelAxis = this._channel_dim < 0 ? this._channel_dim + this._bo.dimension : this._channel_dim;
		const reduceAxes = Array.from({ length: this._bo.dimension }, (unused, axis) => axis);
		reduceAxes.splice(channelAxis, 1);
		const count = reduceAxes.reduce((acc, axis) => acc * this._bo.sizes[axis], 1);

		if (!this._offsetname) {
			const offsetGrad = this._bo.reduce((acc, value) => acc + value / count, 0, reduceAxes, true);
			this._offset.broadcastOperate(optimizer.delta("offset", offsetGrad), (value, delta) => value - delta);
		}
		if (!this._scalename) {
			const scaleGrad = this._bo.reduce(
				(acc, value, index) => acc + value * this._xh.at(index) / count,
				0,
				reduceAxes,
				true
			);
			this._scale.broadcastOperate(optimizer.delta("scale", scaleGrad), (value, delta) => value - delta);
		}
	}

	/**
	 * Serializes the layer configuration.
	 *
	 * Scale and offset are emitted as the node name when one is set, otherwise
	 * as the result of `toArray()` when the value has that method, otherwise as
	 * the raw value.
	 *
	 * @returns {object} Plain configuration object with `type: "batch_normalization"`
	 */
	toObject() {
		return {
			type: "batch_normalization",
			scale: this._scalename || this._scale.toArray?.() || this._scale,
			offset: this._offsetname || this._offset.toArray?.() || this._offset,
			epsilon: this._epsilon,
			channel_dim: this._channel_dim,
			input_mean: this._input_mean,
			input_var: this._input_var,
		};
	}
}

BatchNormalizationLayer.registLayer();
0 commit comments