什麼是Context Manager

828
Context Manager
Context Manager

Context Manager,又稱作上下文管理器,是Python裡的概念,主要用於管理資源的分配和釋放,使用Context Manager可以確保資源在結束使用後會被正確的清除乾淨,避免資源浪費。而Context Manager通常會和withasync with語法一起使用,也方便我們閱讀這段程式碼背後是以什麼樣的形式處理。

Note: async 在Python中表示異步,要是對Async沒有概念的話可以參考這篇文章唷 -> 同步 (Synchronous) V.S 異步 (Asynchronous)

Context Manager的實現方法

from typing import Optional, Type
from types import TracebackType


Class MyContext:
    def __enter__(self) -> "MyContext":
        print("Enter!!!!")
        return self
    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        exc_type: Optional[TracebackType],
    ) -> None:
        print("About to exit soon...")
        if exc_type is not None:
            # 如果在過程中發生錯誤則exc_type回報錯
            ...
        print("Entry point")

with MyContext() as my_context:
    # do something here
    # 結束後Context Manager會自動幫你釋放過程中佔用的資源
    

__enter__ 定義了當with MyContext 被呼叫時會優先被執行的程式,會回傳物件本身

__exit__ 在Context Manager調用期間結束後會執行的snippet

在Async的程式中一樣可以使用Context Manager,區別在於 __enter__ 現在變成 __aenter__,而__exit__變成__aexit__with MyContext as my_context變成async MyContext as my_context

Context Manager的使用時機

Context Manager的用途在於資源管理與釋放,所以舉凡需要調用到外部資源的功能都可以利用Context Manager來幫忙處理,比如:

1. Python內建的openbuiltin,用來檔案讀寫

with open(file_name, "w+") as f:
    f.write("Hello World")
    # do something

with open(json_config, "r") as f:
    config_data = json.load(f)
    # do something


# async example
# aiofiles 有內建的 lock來防止檔案在過程中被覆寫
import aiofiles

async with aiofiles.open(filepath, mode="r") as f:
    info = await r.read()

# 也可以自己使用lock來避免檔案被多個請求複寫
async def write_to_file(filename, data):
    async with file_lock:
        with open(filename, 'a') as file:
            file.write(data)

2. 對外部 Redis cache做操作

# Redis cache在 async下的範例

import asyncio
from redis.asyncio import Redis

class RedisCache:
    def __init__(self) -> None:
        self.redis: Optional[Redis] = None

    async def __aenter__(self) -> RedisCache:
        try:
            # 和redis cache交互的過程中設置timeout=30
            async with async_timeout.timeout(30):
                self.redis = await aioredis.from_url("redis://localhost", decode_responses=True, encoding="utf-8")
        except asyncio.exceptions.TimeoutError:
            print("Redis connection timeout!")

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        exc_type: Optional[TracebackType],
    ) -> None:
        if self.redis:
            await self.redis.close()

    async def get(self, key: str) -> Optional[str]:
        return await self.redis.get(key)

    async def set(self, key: str, value: str) -> None:
        return await self.redis.set(key, value)

async def my_code() -> None:
    async with RedisCache() as redis:
        await redis.set("my_key", "my_value")
        value = await redis.get("my_key")
        print(value)
        
def main() -> None:
    asyncio.run(my_code())

上面的範例中,我們使用async with async_timeout.timeout(30) 來確保每次的操作都會在30秒之內結束,以避免資源變成orphan process,當然,同為外部資源的db也是同樣的道理。

異步程式中Context Manager與Mock測試範例

良好的開發者針對功能去做測試是非常基本且重要的,雖說寫unit test並不能防範所有可能出現的corner error,但是寫測試至少會幫助確保你的功能邏輯正如同你的假設,當我們的商業邏輯對外部有依賴時,為了讓測試聚焦在我們想測試的地方,當涉及外部依賴的測試就必須要使用測試框架裡的Mock來輔助我們。

假如今天我有一個遠端的RemoteClient必須要遠端連線並使用SSH的方式來獲取相關資訊,而其中有一部分的程式碼如下,他使用asyncssh的module來遠端連線到其他主機

class RemoteClient:
    def __init__(
        self,
        host_ip: str,
        port: int,
        user: str,
        password: str,
    ) -> None:
        ...

    async run(
        self,
        command: str,
    ) -> str:
        async with asyncssh.connect(
            host=self.host_ip,
            port=self.port,
            username=self.user,
            password=self.password,
        ) as ssh_conn:
            result = await conn.run(command)
            if result.stderr:
                output = result.stderr.decode("utf-8") if isinstance(result.stderr, bytes) else result.stderr
            elif result.stdout:
                output = result.stdout.decode("utf-8") if isinstance(result.stdout, bytes) else result.stdout

            if result.returncode:
                raise SubprocessError(output)
        return output

從上面可以看到,它在run方法中使用了Context Manager來處理異步的ssh連線,並在結束後釋放資源,那問題就來了,身為開發者,假設我想要去測試是否asyncssh 可以正確傳入我丟進去的參數,那我要怎麼測試?

這邊就要用到測試裡的Mockpatch,嚴格說起來Python裡的Mock混合了測試框架下StubMock的概念,定義中,針對有外部依賴的測試中,Stub物件定義了物件回傳的靜態結果,而Mock紀錄物件在交互過程中的次數、參數等屬性,而實際上在Python的Mock用法混合了兩者,你可以把它想像成一個萬用袋,參考如下

# Stub
>>> from unittest.mock import Mock
>>> a = Mock()
>>> a.bbb = 5
>>> a.bbb
5
# Mock
>>> def sum(num: int, num2: int) -> int:
...     return num + num2
... 
>>> a.sum = sum
>>> a.sum(5,10)
15

除此之外,patch的使用時機是當在測試時,你不想要真的去執行一自定義的函式,則你可以利用patch來自定義這個特定function或參數應該回傳的值,用法如下:

利用patch無論你傳的參數是什麼,你永遠都可以預期sum的回傳值會是自定義的值!

# myfoo.py

def sum(num1: int, num2: int) -> int:
    return num1 + num2


# test_myfoo.py
from unittest.mock import patch
from myfoo import sum


with patch("myfoo.sum", return_value=100):
    assert sum(500, 1000) == 100

初步瞭解Python裡的Mockpatch用法後就可以來看一下要怎麼測試 asyncssh.connect 會傳入我們想要的參數了,這一題要注意的點是,在使用patch時,也必須要使用Context Manager的方式來呼叫,由於asyncssh.connect是以Context Manager的形式來執行遠端連線,而在使用patch 也是以Context Manager的形式來呼叫with patch,為了測試asyncssh,等於是我們會在Context Manager(with patch)裡面再包一個Context Manager(async with asyncssh) ,為了讓測試碼可已傳入正確的型別,我們可以透過定義一個wrapper來包裹測試的Mock物件,並讓傳進去的Mock來在測試中做交互,總而言之,wrapper的目的是為了讓Mock物件能在測試時被視為是Context Manager的一種,直接看程式碼可能比較好懂

from unittest.mock import AsyncMock
import pytest


# for running asynchronous
pytestmark = pytest.mark.anyio

class AsyncContextManagerMock:
    def __init__(self, return_value: AsyncMock) -> None:
        self.return_value = return_value

    async def __aenter__(self) -> AsyncMock:
        return self.return_value

    async def __aexit__(self, *args: Any, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        pass


class AsyncSSHRunMock(AsyncMock):
    # mock asyncssh 中會被呼叫的參數
    @property
    def returncode(self) -> int:
        return 0

    @property
    def stderr(self) -> None:
        return None

    @property
    def stdout(self) -> str:
        return "mock"

### 測試
async def test_remote_cmd_client_without_sudo_by_default() -> None:
    mock_async_ssh_conn = AsyncMock()

    client = RemoteCmdClient("127.0.0.1", 22, "urbanfish", "my_password")
    with patch(
        "asyncssh.connect", return_value=AsyncContextManagerMock(mock_async_ssh_conn)
    ) as asyncssh_mock:
        # 模擬執行的指令
        await client.run_command("echo", "'I am Groot'", ">", "/tmp/groot.txt")
        # asyncssh_mock 在Context Manager結束後就會被移除,所以只能在with的block裡做判斷
        asyncssh_mock.assert_called_once_with(
            "127.0.0.1",
            port=22,
            username="urbanfish",
            password="my_password",
        )

    # 注意這邊mock_async_ssh_conn是在Context Manager釋放資源之後,才能夠去確認Mock的run紀錄中查看是否在過程中有await
    mock_async_ssh_conn.run.assert_awaited_once_with(
        "echo 'I am Groot' > /tmp/groot.txt",
    )

上文可以看出我們讓asyncssh.connect 的回傳值等於AsyncContextManagerMock(mock_async_ssh_conn)AsyncContextManagerMock是一個Context Manager Wrapper,讓asyncssh_mock可以執行mock_async_ssh_conn裡的__aenter__,並在過程中回傳Mock物件並查看參數是否有傳入我們預期的參數,而mock_async_ssh_conn.run必須要等到Context Manager釋放資源之後我們才可以去查看run有被執行預期的指令,若你在with的block底下執行測試會報錯!透過上述的小技巧就可以成功測試參數的傳入以及預期command的呼叫。

補充

在Golang裡面一樣也有上下文管理器的觀念,用的就是在Go1.7之後引入的ctx Context,比如你可以使用context.WithCancel(context.Background)來處理一系列的Goroutine 任務,用法如下,使用context的好處是每個context可以管理同一系列下的多個任務,比如你打了一個API,一個API在後端中包含許多邏輯需要處理,可利用每個API都夾帶一個context的方式來確保任務的管理,多個API request會產生多個context來個別管理。

import (
    "context"
    "fmt"
    "time"
)

func main() {
    ctx, cancel := context.WithCancel(context.Background())
    go worker(ctx, "worker1")
    go worker(ctx, "worker2")
    go worker(ctx, "worker3")
    time.sleep(3 * time.Second)
    fmt.Println("Go home now!")
    cancel()
    time.sleep(3 * time.Second)
}

func worker(ctx context.Context, name string) {
    for {
        select {
        case <- ctx.Done()
            fmt.Println(name, "drop the Task")
            return
        default:
            fmt.Println(name, "on the Task...")
            time.sleep(1 * time.Second)
        }
     }
}

除了context之外,你也可以使用defer來手動確保檔案關閉,比如:

myFile := "myfile.txt"
file, err := os.Open(myFile)
	if err != nil {
		log.Printf("Cannot open %s: %v\n", myFile, err)
	}
	defer func(file *os.File) {
		_ = file.Close()
	}(file)

差別只在於Python的Context Manager會自動幫你管理資源,而Golang裡需要你自己手動去做設定,但當然相對的可控性會更佳。

這就是本次針對Context Manager的介紹啦,除了介紹Context Manager之外還額外介紹到底要怎麼去測試他們,希望對大家有幫助唷!😁

LEAVE A REPLY

Please enter your comment!
Please enter your name here