理解 Rust impl Trait 機制

impl Trait 本身可看作 Rust 補充類型系統的修補,但正確理解其必要性需要從 Rust 語言更底層的問題談起。

impl Trait作為參數

impl Trait的出現,直接目的為:填補 closure 和 iterator 機制遺留的問題。

以 closure 為例。開發者可能要面臨將函數作為參數傳入另一函數的情景,通常大家會選擇函數指針fn來完成此任務。

問題如下:closure 也可以傳給函數指針fn嗎?

答案:有時可以,有時不行

不可傳的情況

fn main() {
    let f = |x: i32| -> i32 { x + 1 };
    let f_1 = |x: i32| -> i32 { f(f(x)) };
    let x = 1;
    println!("{}", pass_func_by_fn(f_1, x));
    println!("{}", pass_func_by_fn(normal_f_1, x));
}

fn pass_func_by_fn(f: fn(i32) -> i32, x: i32) -> i32 {
    f(x)
}

fn normal_f(x: i32) -> i32 {
    x + 1
}

fn normal_f_1(x: i32) -> i32 {
    normal_f(normal_f(x))
}

試圖編譯此段代碼,會在pass_func_by_fn(f_1, x))處得到編譯錯誤

error[E0308]: mismatched types
 --> main.rs:5:36
  |
3 |     let f_1 = |x: i32| -> i32 { f(f(x)) };
  |               --------------------------- the found closure
4 |     let x = 1;
5 |     println!("{}", pass_func_by_fn(f_1, x));
  |                                    ^^^ expected fn pointer, found closure
  |
  = note: expected fn pointer `fn(i32) -> i32`
                found closure `[closure@main.rs:3:15: 3:42 f:_]`

error: aborting due to previous error

For more information about this error, try `rustc --explain E0308`.

註釋掉此行後

fn main() {
    let f = |x: i32| -> i32 { x + 1 };
    let f_1 = |x: i32| -> i32 { f(f(x)) };
    let x = 1;
    // println!("{}", pass_func_by_fn(f_1, x));
    println!("{}", pass_func_by_fn(normal_f_1, x));
}

程序可以正確編譯並輸出

3

重點:fn是一個特定類型,而每個closure有自己獨一無二(且開發者不可知)的類型。

此處編譯器沒有計算出f_1的類型和fn(i32) -> i32類型的兼容性,故發生錯誤。照此說法,豈不是closure們基本無法傳給fn類型的參數了?錯。有些時候是可以傳的。

可傳的情況

fn main() {
    let f = |x: i32| -> i32 { x + 1 };
    let x = 1;
    println!("{}", pass_func_by_fn(f, x));
    println!("{}", pass_func_by_fn(normal_f, x));
}

fn pass_func_by_fn(f: fn(i32) -> i32, x: i32) -> i32 {
    f(x)
}

fn normal_f(x: i32) -> i32 {
    x + 1
}

編譯此段代碼不會發生任何錯誤,且運行後得到正確結果

2
2

是不是很奇妙。所以部分用戶流傳的「Rust 裡不可將 closure 傳給函數指針」說法為一謬誤。另一個事實是,儘管少數情景下可行,大部分 closure 確實無法被正確 pass 給fn指針。

首先我們要弄清楚 fn 和 Fn/FnMut/FnOnce 的區別:fn 是具體的類型,而 Fn/FnMut/FnOnce 是 trait;fn 是指向函數的指針,而 Fn/FnMut/FnOnce 的具體實例是函數本身,用類型來類比的話,即前者是&T,後者是 T。

fn 是函數指針類型,其要求指向的函數必須同時實現 Fn/FnMut/FnOnce 三個 trait,一般函數通常是符合的,但 closure 由於多會捕獲環境而可能不符合 Fn/FnMut,所以只有那些不捕獲環境的 closure 才能使用 fn 指針代指。

使用Generic來傳遞函數

為瞭解決fn指針可能無法應用於 closure 的問題,利用 closure 對象會根據情況自動實現Fn系列 trait(Fn/FnMut/FnOnce)這一機制,故可編寫 Generic 函數來實現功能。

fn main() {
    let f = |x: i32| -> i32 { x + 1 };
    let f_1 = |x: i32| -> i32 { f(f(x)) };
    let x = 1;
    println!("{}", pass_func_by_generic(f_1, x));
    println!("{}", pass_func_by_generic(normal_f_1, x));
}

fn pass_func_by_generic<T, U>(f: T, x: U) -> U
where
    T: Fn(U) -> U,
{
    f(x)
}

fn normal_f(x: i32) -> i32 {
    x + 1
}

fn normal_f_1(x: i32) -> i32 {
    normal_f(normal_f(x))
}

此段代碼可正確編譯,運行後輸出結果

3
3

Generic 機制使得 Rust 在編譯期為f_1normal_f_1分別生成對應的pass_func_by_generic函數,即 Static Dispatching。

使用作為參數的impl Trait來傳遞函數

首先看impl Trait的用法

fn main() {
    let f = |x: i32| -> i32 { x + 1 };
    let f_1 = |x: i32| -> i32 { f(f(x)) };
    let x = 1;
    println!("{}", pass_func_by_impl(f_1, x));
    println!("{}", pass_func_by_impl(normal_f_1, x));
}

fn pass_func_by_impl<U>(f: impl Fn(U) -> U, x: U) -> U {
    f(x)
}

fn normal_f(x: i32) -> i32 {
    x + 1
}

fn normal_f_1(x: i32) -> i32 {
    normal_f(normal_f(x))
}

此段代碼可以正確編譯,並輸出結果

3
3

要理解impl Trait的作用,還得回頭對比前文中的 Generic 寫法:

fn pass_func_by_generic<T, U>(f: T, x: U) -> U
where
    T: Fn(U) -> U,
{
    f(x)
}

注意 T 類型和 U 類型有嚴格邏輯聯繫,此段函數簽名中的 T 類型為信息冗餘,所以 T 類型可以消除,接下來試圖改寫pass_func_by_generic<T, U>函數到pass_func_by_generic<U>

(改寫步驟 1)(尚為錯誤代碼↓)

fn pass_func_by_generic<U>(f: Fn(U) -> U, x: U) -> U {
    f(x)
}

錯誤原因:Rust 2018 後,trait 對象須使用dyn Trait語法

(改寫步驟 2)(尚為錯誤代碼↓)

fn pass_func_by_generic<U>(f: dyn Fn(U) -> U, x: U) -> U {
    f(x)
}

trait 對象 f 的 size 在編譯期不確定,故只能改為 borrow

(改寫步驟 3)(正確代碼↓)

fn pass_func_by_generic<U>(f: &dyn Fn(U) -> U, x: U) -> U {
    f(x)
}

改寫pass_func_by_generic<U>函數的調用方法,使用 borrow 傳遞函數。

println!("{}", pass_func_by_generic(&f_1, x));
println!("{}", pass_func_by_generic(&normal_f_1, x));

程序可正確輸出

3
3

缺點:雖然達成了將pass_func_by_generic<T, U>改寫至pass_func_by_generic<U>的目的,但pass_func_by_generic從完全 Static Dispatching 變為部分 Dynamic Dispatching,引入了性能損耗。

為了幫助開發者既能保持完整的 Static Dispatching,又能消除函數簽名中信息冗餘的類型信息,只有引入impl Trait作為參數類型:

fn pass_func_by_impl<U>(f: impl Fn(U) -> U, x: U) -> U {
    f(x)
}

impl Trait作為 Static Dispatching 的補充工具,使得 Rust 可以在編譯期像pass_func_by_generic<T, U>一樣,為f_1normal_f_1生成各自對應的pass_func_by_impl<U>函數,而不引入任何Dynamic Dispatching機制,降低代碼性能。

impl Trait作為返回值

問題:如何在 Rust 中從函數返回一個函數

一般函數

fn main() {
    let x = 1;
    println!("{}", get_func()(x));
}

fn normal_f(x: i32) -> i32 {
    x + 1
}

fn get_func() -> fn(i32) -> i32 {
    normal_f
}

代碼正確編譯,運行後輸出

2

Closure

前文已經提到fn指針不一定和 closure 適配,所以需要利用Fn系列 Trait。

fn main() {
    let x = 1;
    println!("{}", get_func()(x));
}

fn get_func() -> Box<dyn Fn(i32) -> i32> {
    Box::new(|x| x + 1)
}

代碼正確編譯,運行後輸出

2

之所以這裡使用 Dynamic Dispatching,是因為在 Rust 中,如下寫法無法正確編譯(錯誤代碼↓)

fn get_func<T>() -> T
where
    T: Fn(i32) -> i32,
{
    |x| x + 1
}

如果使用 Generic Trait Type 作為函數返回值簽名,提供實現了此 Trait 的具體 Type 信息的權利在函數的 caller 身上,比如常用的parse函數:let _x: u32 = "14".parse()?;,函數自身無權把此 Type 給直接具體化

那不用 Trait,直接返回 closure 的具體類型呢?別忘了前面也提到過:每個closure有自己獨一無二(且開發者不可知)的類型。 所以開發者無法寫出作為返回值的 closure 的具體類型。

從棧上返回閉包

一種 work around 即前文提到的Box<dyn Fn(i32) -> i32>,副作用為:引入了 Dynamic Dispatching 和堆內存。

為了直接從棧上返回閉包,Rust 在 1.26 後提供impl Trait作為返回值的語法來覆蓋此情景:

fn main() {
    let x = 1;
    println!("{}", get_func()(x));
}

fn get_func() -> impl Fn(i32) -> i32 {
    |x| x + 1
}

此段代碼可正確編譯。

總結

impl Trait說穿,即為代指某個不具名、但在編譯期可確定的具體類型。相較於純 Generic 寫法,impl Trait覆蓋了 Generic 機制無法提供的功能,但其也有特殊限制:比如不可以在 Trait 的實現塊中使用。改寫現有的 Generic 函數到impl Trait亦可能會 break 現有代碼,因為 Generic 函數可以通過明確寫出範型參數的具體類型來代指具體函數,而impl Trait不可以。

建議在新代碼中多合理使用impl Trait


CC BY-SA 4.0

本文使用 CC BY-SA 4.0 授權

標籤:

分類:

更新時間: