This commit is contained in:
William Ball 2021-06-23 16:03:53 -04:00
parent bc8b69c1c0
commit 359c25e031
2 changed files with 29 additions and 23 deletions

View file

@ -1,5 +1,4 @@
use gad::prelude::*; use gad::prelude::*;
// use gad::{arith::ArithAlgebra, core::CoreAlgebra, error::Result, prelude::GradientStore, Graph1};
use nalgebra::{SMatrix, SVector}; use nalgebra::{SMatrix, SVector};
use num::Complex; use num::Complex;
@ -7,7 +6,7 @@ type F = f64;
type Matrix2x2 = SMatrix<Complex<F>, 2, 2>; type Matrix2x2 = SMatrix<Complex<F>, 2, 2>;
fn sum( fn sum(
g: &mut Graph1, g: &mut GraphN,
ar: &Value<F>, ar: &Value<F>,
ai: &Value<F>, ai: &Value<F>,
br: &Value<F>, br: &Value<F>,
@ -17,7 +16,7 @@ fn sum(
} }
fn product( fn product(
g: &mut Graph1, g: &mut GraphN,
ar: &Value<F>, ar: &Value<F>,
ai: &Value<F>, ai: &Value<F>,
br: &Value<F>, br: &Value<F>,
@ -35,7 +34,7 @@ fn product(
} }
fn division( fn division(
g: &mut Graph1, g: &mut GraphN,
ar: &Value<F>, ar: &Value<F>,
ai: &Value<F>, ai: &Value<F>,
br: &Value<F>, br: &Value<F>,
@ -69,8 +68,8 @@ fn division(
Ok((f1, f2)) Ok((f1, f2))
} }
fn mobius_derivative(mat: Matrix2x2) -> Result<Complex<F>> { fn mobius_derivative(mat: Matrix2x2) -> Result<(Value<F>, Value<F>)> {
let mut g = Graph1::new(); let mut g = GraphN::new();
let x = g.variable(0.0); let x = g.variable(0.0);
let y = g.variable(0.0); let y = g.variable(0.0);
let a11r = g.constant(mat[(0, 0)].re); let a11r = g.constant(mat[(0, 0)].re);
@ -85,25 +84,43 @@ fn mobius_derivative(mat: Matrix2x2) -> Result<Complex<F>> {
let (numeratorr, numeratori) = { let (numeratorr, numeratori) = {
let (prodr, prodi) = product(&mut g, &x, &y, &a11r, &a11i)?; let (prodr, prodi) = product(&mut g, &x, &y, &a11r, &a11i)?;
sum(&mut g, &a12r, &a12i, &prodr, &prodi)? sum(&mut g, &a12r, &a12i, &prodr, &prodi)?
}; };
let (denominatorr, denominatori) = { let (denominatorr, denominatori) = {
let (prodr, prodi) = product(&mut g, &x, &y, &a21r, &a21i)?; let (prodr, prodi) = product(&mut g, &x, &y, &a21r, &a21i)?;
sum(&mut g, &a22r, &a22i, &prodr, &prodi)? sum(&mut g, &a22r, &a22i, &prodr, &prodi)?
}; };
let (resultr, resulti) = division(&mut g, &numeratorr, &numeratori, &denominatorr, &denominatori)?; let (resultr, resulti) = division(
&mut g,
&numeratorr,
&numeratori,
&denominatorr,
&denominatori,
)?;
let x = x.gid()?; let x = x.gid()?;
let gradients1 = g.evaluate_gradients(resultr.gid()?, 1f64)?; let one = g.constant(1.0);
let gradients2 = g.evaluate_gradients(resulti.gid()?, 1f64)?; let one2 = g.constant(1.0);
let gradients1 = g.compute_gradients(resultr.gid()?, one)?;
let gradients2 = g.compute_gradients(resulti.gid()?, one2)?;
let du_dx = gradients1.get(x).unwrap(); let du_dx = gradients1.get(x).unwrap();
let dv_dx = gradients2.get(x).unwrap(); let dv_dx = gradients2.get(x).unwrap();
Ok(Complex::new(*du_dx, -*dv_dx)) Ok(Complex::new(*du_dx.data(), -*dv_dx.data()))
}
fn next_order_derivative(
g: &mut GraphN,
expr: &GradientId<F>,
var: &GradientId<F>,
) -> Result<Value<F>> {
let dz = g.constant(1.0);
let dz_d = g.compute_gradients(*expr, dz)?;
Ok(dz_d.get(*var).unwrap().clone())
} }
fn power_method<const N: usize>( fn power_method<const N: usize>(
@ -149,21 +166,10 @@ fn main() {
// Complex::new(-2.0 / 12.0, -10.0 / 12.0), // Complex::new(-2.0 / 12.0, -10.0 / 12.0),
// ); // );
let mobius = Matrix2x2::new(
Complex::new(2.0, 0.0),
Complex::new(-3.0, 0.0),
Complex::new(4.0, 0.0),
Complex::new(0.0, -5.0),
);
println!("{}", mobius_derivative(mobius).unwrap());
// let r = Matrix2x2::new( // let r = Matrix2x2::new(
// Complex::new(-6.0 / 12.0, 8.0 / 12.0), // Complex::new(-6.0 / 12.0, 8.0 / 12.0),
// Complex::new(0.0, 11.0 / 12.0), // Complex::new(0.0, 11.0 / 12.0),
// Complex::new(0.0, 4.0 / 12.0), // Complex::new(0.0, 4.0 / 12.0),
// Complex::new(-6.0 / 12.0, -8.0 / 12.0), // Complex::new(-6.0 / 12.0, -8.0 / 12.0),
// ); // );
// secant_method(|x: F| x.cos() - x, 0.0, 1.0, std::f64::EPSILON, 1000);
} }