Rust常用併發示例代碼

記錄幾個常用的併發用法:


1、如何讓線程只創建1次

先看一段熟悉的java代碼:

void method1() {
    new Thread(() -> {
        while (true) {
            System.out.println(String.format("thread-id:%s,timestamp:%d",
                    Thread.currentThread().getId(), System.currentTimeMillis()));
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
            }
        }
    }).start();
}

如果method1()被多次調用,就會創建多個線程,如果希望不管調用多少次,只能有1個線程,在不使用線程池的前提下,有1個簡單的辦法:

AtomicBoolean flag = new AtomicBoolean(false);

void method1() {
    //AtomicBoolean保證線程安全,getAndSet是1個原子操作,method1只有第1次執行時,才能if判斷才能通過
    if (!flag.getAndSet(true)) {
        new Thread(() -> {
            while (true) {
                System.out.println(String.format("thread-id:%s,timestamp:%d",
                        Thread.currentThread().getId(), System.currentTimeMillis()));
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                }
            }
        }).start();
    }
}

在rust中也可以套用這個思路,完整代碼如下:


use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

//聲明1個全局靜態變量(AtomicBool能保證線程安全)
static FLAG: AtomicBool = AtomicBool::new(false);

fn method1() {
    //fetch_update類似java中的AtomicBoolean.getAndSet
    if FLAG
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(true))
        .unwrap()
    {
        std::thread::spawn(move || loop {
            println!(
                "thread-id:{:?},timestamp:{}",
                thread::current().id(),
                timestamp()
            );
            thread::sleep(Duration::from_millis(1000));
        });
    }
}

//輔助方法,獲取系統時間戳(不用太關注這個方法)
fn timestamp() -> i64 {
    let start = SystemTime::now();
    let since_the_epoch = start
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    let ms = since_the_epoch.as_secs() as i64 * 1000
        + (since_the_epoch.subsec_nanos() as f64 / 1_000_000.0) as i64;
    ms
}
fn main() {
    //調用2次
    method1();
    method1();

    //用1個死循環,防止main線束(僅演示用)
    loop {
        thread::sleep(Duration::from_millis(1000));
    }
}

輸出:

thread-id:ThreadId(2),timestamp:1662265684621
thread-id:ThreadId(2),timestamp:1662265685623
thread-id:ThreadId(2),timestamp:1662265686627
thread-id:ThreadId(2),timestamp:1662265687628
thread-id:ThreadId(2),timestamp:1662265688630
...

從輸出的線程id上看,2次method1()只創建了1個線程


2、如何讓線程執行完再繼續

fn main() {
    let mut thread_list = Vec::<thread::JoinHandle<()>>::new();
    for _i in 0..5 {
        let t = thread::spawn(|| {
            for n in 1..3 {
                println!("{:?}, n:{}", thread::current().id(), n);
                thread::sleep_ms(5);
            }
        });
        thread_list.push(t);
    }
    //運行後會發現,大概率只有下面這行會輸出,因爲main已經提前線束了,上面的線程沒機會執行,就被順帶着被幹掉了
    println!("main thread");
}

上面這段代碼,如果希望在main主線程結束前,讓所有創建出來的子線程執行完,可以使用join方法

fn main() {
    let mut thread_list = Vec::<thread::JoinHandle<()>>::new();
    for _i in 0..5 {
        let t = thread::spawn(|| {
            for n in 1..3 {
                println!("{:?}, n:{}", thread::current().id(), n);
                thread::sleep_ms(5);
            }
        });
        thread_list.push(t);
    }
 
    //將所有線程join,強制執行完後,才繼續
    for t in thread_list {
        t.join().unwrap();
    }
 
    println!("main thread");
}

3、線程互斥鎖

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

fn main() {
    //聲明1個互斥鎖Mutex,注意在多線程中使用時,必須套一層Arc
    let flag = Arc::new(Mutex::new(false));
    let mut handlers = vec![];
    for _ in 0..10 {
        let flag = Arc::clone(&flag);
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            //只有1個線程會lock成功
            let mut b = flag.lock().unwrap();
            if !*b {
                //搶到鎖的,把標誌位改成true,其它線程就沒機會執行println
                *b = true;
                println!("sub\t=>\t{:?}", thread::current().id());
            }
        });
        handlers.push(handle);
    }
    for h in handlers {
        h.join().unwrap();
    }
    println!("main\t=>\t{:?}", thread::current().id());
}

上面的效果,9個子線程中,只會有1個搶到鎖,並輸出println,輸出類似下面這樣:

sub     =>      ThreadId(2)
main    =>      ThreadId(1)

4、線程之間發送數據

use std::sync::mpsc;
use std::sync::mpsc::{Receiver, Sender};
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
 
fn main() {
    let (sender, receiver) = mpsc::channel();
    let t1 = send_something(sender);
    let t2 = receive_something(receiver);
 
    t1.join().unwrap();
    t2.join().unwrap();
}
 
/**
 * 線程發送消息測試
 */
fn send_something(tx: Sender<String>) -> JoinHandle<()> {
    thread::spawn(move || {
        //模擬先做其它業務處理
        thread::sleep(Duration::from_millis(100));
 
        let msg_list = vec![
            String::from("a"),
            String::from("b"),
            String::from("c"),
            //約定:\n是數據的結束符
            String::from("\n"),
        ];
 
        //發送一堆消息
        for msg in msg_list {
            tx.send(msg).unwrap();
        }
    })
}
 
/**
 * 線程收消息
 */
fn receive_something(rx: Receiver<String>) -> JoinHandle<()> {
    thread::spawn(move || loop {
        //try_recv 不會阻塞
        let s = rx.try_recv();
        if s.is_ok() {
            let msg = s.unwrap();
            if msg == "\n" {
                //約定:收到\n表示後面沒數據了,可以退出
                println!("end!");
                break;
            } else {
                println!("got msg:{}", msg);
            }
        }
        //模擬沒數據時乾點其它事情
        println!("do another thing!");
        thread::sleep(Duration::from_millis(100));
    })
}

輸出:

do another thing!
do another thing!
got msg:a
do another thing!
got msg:b
do another thing!
got msg:c
do another thing!
end!

5、線程池示例

先要引用threadpool的依賴

[dependencies]
threadpool="1.8.1"

然後就可以使用了:

use std::thread;
use std::time::Duration;
use threadpool::ThreadPool;

fn main() {
    let n_workers = 3;
    //創建1個名爲test-pool的線程池
    let pool = ThreadPool::with_name(String::from("test-pool"), n_workers);

    for _ in 0..10 {
        pool.execute(|| {
            println!(
                "{:?},{:?}",
                thread::current().id(),
                thread::current().name()
            );
            thread::sleep(Duration::from_millis(100));
        });
    }

    //待線程池中的所有任務都執行完
    pool.join();
}

輸出:

ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")
ThreadId(3),Some("test-pool")
ThreadId(4),Some("test-pool")
ThreadId(2),Some("test-pool")

6、指定線程名稱

use std::thread;

fn main() {
    let t1 = thread::Builder::new()
        //子線程命名
        .name(format!("my-thread"))
        .spawn(|| {
            //打印子線程的id和name
            println!(
                "{:?},{:?}",
                thread::current().id(),
                thread::current().name()
            );
        })
        .unwrap();
    t1.join().unwrap();

    //打印主線程的id和name
    println!(
        "{:?},{:?}",
        thread::current().id(),
        thread::current().name()
    );
}

輸出:

ThreadId(2),Some("my-thread")
ThreadId(1),Some("main")

7、如何暫停/恢復線程運行

use std::time::Duration;

use std::thread;

fn main() {
    let (tx, rx) = std::sync::mpsc::channel();
    let t = thread::spawn(move || loop {
        //獲取當前時間的秒數
        let seconds = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        println!("{}", seconds);
        //每當秒數爲5的倍數,就把自己暫停,同時對外發消息pause
        if seconds % 5 == 0 {
            tx.send("pause").unwrap();
            println!("\nwill be parked !!!");
            //將自己暫停
            thread::park();
        }
        thread::sleep(Duration::from_millis(1000));
    });

    //不斷收消息,發現是pause後,過3秒將線程t解封
    loop {
        let flag = rx.recv();
        if flag.is_ok() && flag.unwrap() == "pause" {
            thread::sleep(Duration::from_millis(3000));
            //解封t
            t.thread().unpark();
            println!("unparked !!!\n");
        }
    }
}

這樣就實現了一個簡易版的ScheudleThread,可以週期性的運行,運行效果:

1662278909
1662278910

will be parked !!!
unparked !!!

1662278914
1662278915

will be parked !!!
unparked !!!

1662278919
1662278920
...

8、信號量

推薦使用tokio的信號量實現

[dependencies]
tokio = { version = "1.21.0", features = ["full"] }

示例:

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    println!("1-{:?}", semaphore);

    let _s = semaphore.clone().acquire_owned().await.unwrap();
    //消耗了1個信號量後,只剩下0
    println!("2-{:?}", semaphore);

    //此時再嘗試獲取信號量,會卡在這裏,直到有人把信號號釋放歸還
    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("3-{:?}", semaphore);
    println!("done");
}

輸出:

1-Semaphore { ll_sem: Semaphore { permits: 1 } }
2-Semaphore { ll_sem: Semaphore { permits: 0 } }
...會卡在這裏

要歸還信號號,可以使用drop方法

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(1));
    println!("1-{:?}", semaphore);

    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("2-{:?}", semaphore);

    //信號號使用後,要記得歸原
    drop(_s);
    println!("歸原後-{:?}", semaphore);

    //只要剩餘信號量>0,就不會卡住了
    let _s = semaphore.clone().acquire_owned().await.unwrap();
    println!("3-{:?}", semaphore);
    println!("done");
}

輸出:

1-Semaphore { ll_sem: Semaphore { permits: 1 } }
2-Semaphore { ll_sem: Semaphore { permits: 0 } }
歸原後-Semaphore { ll_sem: Semaphore { permits: 1 } }
3-Semaphore { ll_sem: Semaphore { permits: 0 } }
done

9、條件變量Condvar

這個東西,要跟Mutex互斥鎖一起使用,不要問爲什麼,Condvar的wait方法簽名設計就是這樣的!
image
但其實使用過程中,Mutex的值完全可以跟Condvar沒任何關係,把官網的示例修改了下(注:可能沒啥實際意義,只是出於演示)

use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;

fn current_seconds() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs()
}

fn main() {
    println!("11111\t -> now:{}", current_seconds());
    let pair = Arc::new((Mutex::new(true), Condvar::new()));
    let pair2 = pair.clone();

    thread::spawn(move || {
        thread::sleep(Duration::from_secs(2));
        let &(_, ref cvar) = &*pair2;
        //喚醒被block的線程
        cvar.notify_one();
        println!("thread\t -> now:{}", current_seconds());
    });

    let &(_, ref cvar) = &*pair;
    println!("22222\t -> now:{}", current_seconds());
    //這裏會阻塞住,直到子線程裏notify_one通知
    //這裏可以看出cvar的wait中完全可以傳1個不相關的mutex!
    let no_use = Mutex::new(0);
    let _ = cvar.wait(no_use.lock().unwrap()).unwrap();

    println!("33333\t -> now:{}", current_seconds());
}

這裏main主線程在調用cvar.wait方法時會block住,直到子線程2秒後,cvar.notify_one()將其喚醒,輸出:

11111    -> now:1662285716
22222    -> now:1662285716
thread   -> now:1662285718
33333    -> now:1662285718

參考文章:

Rust語言聖經-多線程併發編程

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章