diff --git a/pygit2/_libgit2/ffi.pyi b/pygit2/_libgit2/ffi.pyi index a0710cae..690abf2b 100644 --- a/pygit2/_libgit2/ffi.pyi +++ b/pygit2/_libgit2/ffi.pyi @@ -87,10 +87,36 @@ class GitRepositoryC: # def _from_c(cls, ptr: 'GitRepositoryC', owned: bool) -> 'Repository': ... pass +class GitRemoteCallbacksC: + # TODO: Several Anys need filling in + version: int + sideband_progress: Any + completion: Any + credentials: Any + certificate_check: Any + transfer_progress: Any + update_tips: Any + pack_progress: Any + push_transfer_progress: Any + push_update_reference: Any + push_negotiation: Any + transport: Any + remote_ready: Any + payload: Any + resolve_url: Any + update_refs: Any + class GitFetchOptionsC: # TODO: FetchOptions exist in _pygit2.pyi # incomplete depth: int + callbacks: GitRemoteCallbacksC + custom_headers: GitStrrayC + +class GitPushOptionsC: + # TODO incomplete + callbacks: GitRemoteCallbacksC + custom_headers: GitStrrayC class GitSubmoduleC: pass @@ -225,7 +251,11 @@ class GitRepositoryInitOptionsC: origin_url: ArrayC[char] class GitCloneOptionsC: - pass + # TODO: Several Anys need filling in + repository_cb: Any + repository_cb_payload: Any + remote_cb: Any + remote_cb_payload: Any class GitPackbuilderC: pass @@ -256,6 +286,8 @@ def new(a: Literal['git_repository **']) -> _Pointer[GitRepositoryC]: ... @overload def new(a: Literal['git_remote **']) -> _Pointer[GitRemoteC]: ... @overload +def new(a: Literal['git_remote_callbacks *']) -> GitRemoteCallbacksC: ... +@overload def new(a: Literal['git_transaction **']) -> _Pointer[GitTransactionC]: ... @overload def new(a: Literal['git_repository_init_options *']) -> GitRepositoryInitOptionsC: ... @@ -276,8 +308,12 @@ def new(a: Literal['git_blob **']) -> _Pointer[GitBlobC]: ... @overload def new(a: Literal['git_clone_options *']) -> GitCloneOptionsC: ... @overload +def new(a: Literal['git_fetch_options *']) -> GitFetchOptionsC: ... +@overload def new(a: Literal['git_merge_options *']) -> GitMergeOptionsC: ... @overload +def new(a: Literal['git_push_options *']) -> GitPushOptionsC: ... +@overload def new(a: Literal['git_blame_options *']) -> GitBlameOptionsC: ... @overload def new(a: Literal['git_annotated_commit **']) -> _Pointer[GitAnnotatedCommitC]: ... @@ -364,6 +400,7 @@ def new( a: Literal['char *[]'], b: list[Any] ) -> ArrayC[char_pointer]: ... # For string arrays def addressof(a: object, attribute: str) -> _Pointer[object]: ... +def new_handle(a: T) -> _Pointer[T]: ... class buffer(bytes): def __init__(self, a: object) -> None: ... diff --git a/pygit2/callbacks.py b/pygit2/callbacks.py index 20bfad40..1072f8cf 100644 --- a/pygit2/callbacks.py +++ b/pygit2/callbacks.py @@ -79,9 +79,15 @@ _Credentials = Username | UserPass | Keypair if TYPE_CHECKING: - from pygit2._libgit2.ffi import GitProxyOptionsC + from pygit2._libgit2.ffi import ( + GitCloneOptionsC, + GitFetchOptionsC, + GitProxyOptionsC, + GitPushOptionsC, + GitStrrayC, + ) - from .remotes import PushUpdate, TransferProgress + from .remotes import PushUpdate, Remote, TransferProgress # # The payload is the way to pass information from the pygit2 API, through # libgit2, to the Python callbacks. And back. @@ -92,6 +98,9 @@ class Payload: repository: Callable | None remote: Callable | None clone_options: Any + fetch_options: Any + push_options: Any + remote_callbacks: Any def __init__(self, **kw: object) -> None: for key, value in kw.items(): @@ -120,12 +129,10 @@ class RemoteCallbacks(Payload): method, or if it's a constant value, pass the value to the constructor, e.g. RemoteCallbacks(credentials=credentials). - You can as well pass the certificate the same way, for example: - RemoteCallbacks(certificate=certificate). + You can as well pass the certificate check callback the same way, for example: + RemoteCallbacks(certificate_check=certificate_check). """ - push_options: Any - def __init__( self, credentials: _Credentials | None = None, @@ -262,6 +269,19 @@ def push_update_reference(self, refname: str, message: str) -> None: Rejection message from the remote. If None, the update was accepted. """ + def custom_headers(self) -> list[str] | None: + """ + Custom headers callback. Override with your own function to return a + list of custom headers that should be used when connecting to, pushing + to, or fetching from the remote. + + Example use case to authenticate with bearer tokens instead of username/password: + + return [f"Authorization: Bearer {token}"] + + Returns: list of header strings or None + """ + class CheckoutCallbacks(Payload): """Base class for pygit2 checkout callbacks. @@ -352,7 +372,22 @@ def stash_apply_progress(self, progress: StashApplyProgress) -> None: @contextmanager -def git_clone_options(payload, opts=None): +def git_custom_headers( + payload: RemoteCallbacks, + opts_custom_headers: Optional['GitStrrayC'] = None, +) -> Generator[StrArray, Any, None]: + custom_headers = payload.custom_headers() or None + with StrArray(custom_headers) as headers_array: + if opts_custom_headers is not None: + headers_array.assign_to(opts_custom_headers) + yield headers_array + + +@contextmanager +def git_clone_options( + payload: RemoteCallbacks, + opts: Optional['GitCloneOptionsC'] = None, +) -> Generator[RemoteCallbacks, Any, None]: if opts is None: opts = ffi.new('git_clone_options *') C.git_clone_options_init(opts, C.GIT_CLONE_OPTIONS_VERSION) @@ -374,7 +409,10 @@ def git_clone_options(payload, opts=None): @contextmanager -def git_fetch_options(payload, opts=None): +def git_fetch_options( + payload: RemoteCallbacks | None, + opts: Optional['GitFetchOptionsC'] = None, +) -> Generator[RemoteCallbacks, Any, None]: if payload is None: payload = RemoteCallbacks() @@ -392,15 +430,16 @@ def git_fetch_options(payload, opts=None): handle = ffi.new_handle(payload) opts.callbacks.payload = handle - # Give back control - payload.fetch_options = opts - payload._stored_exception = None - yield payload + with git_custom_headers(payload, opts.custom_headers): + # Give back control + payload.fetch_options = opts + payload._stored_exception = None + yield payload @contextmanager def git_proxy_options( - payload: object, + payload: 'Remote | RemoteCallbacks', opts: Optional['GitProxyOptionsC'] = None, proxy: None | bool | str = None, ) -> Generator['GitProxyOptionsC', None, None]: @@ -414,20 +453,24 @@ def git_proxy_options( elif type(proxy) is str: opts.type = C.GIT_PROXY_SPECIFIED # Keep url in memory, otherwise memory is freed and bad things happen - payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[attr-defined] - opts.url = payload.__proxy_url # type: ignore[attr-defined] + payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[union-attr] + opts.url = payload.__proxy_url # type: ignore[union-attr] else: raise TypeError('Proxy must be None, True, or a string') yield opts @contextmanager -def git_push_options(payload, opts=None): +def git_push_options( + payload: RemoteCallbacks | None, + opts: Optional['GitPushOptionsC'] = None, +) -> Generator[RemoteCallbacks, Any, None]: if payload is None: payload = RemoteCallbacks() - opts = ffi.new('git_push_options *') - C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION) + if opts is None: + opts = ffi.new('git_push_options *') + C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION) # Plug callbacks opts.callbacks.sideband_progress = C._sideband_progress_cb @@ -448,14 +491,17 @@ def git_push_options(payload, opts=None): handle = ffi.new_handle(payload) opts.callbacks.payload = handle - # Give back control - payload.push_options = opts - payload._stored_exception = None - yield payload + with git_custom_headers(payload, opts.custom_headers): + # Give back control + payload.push_options = opts + payload._stored_exception = None + yield payload @contextmanager -def git_remote_callbacks(payload): +def git_remote_callbacks( + payload: RemoteCallbacks | None, +) -> Generator[RemoteCallbacks, Any, None]: if payload is None: payload = RemoteCallbacks() diff --git a/pygit2/remotes.py b/pygit2/remotes.py index b940e294..265b1783 100644 --- a/pygit2/remotes.py +++ b/pygit2/remotes.py @@ -34,6 +34,7 @@ from . import utils from ._pygit2 import Oid from .callbacks import ( + git_custom_headers, git_fetch_options, git_proxy_options, git_push_options, @@ -189,14 +190,15 @@ def connect( """ with git_proxy_options(self, proxy=proxy) as proxy_opts: with git_remote_callbacks(callbacks) as payload: - err = C.git_remote_connect( - self._remote, - direction, - payload.remote_callbacks, - proxy_opts, - ffi.NULL, - ) - payload.check_error(err) + with git_custom_headers(payload) as custom_headers: + err = C.git_remote_connect( + self._remote, + direction, + payload.remote_callbacks, + proxy_opts, + custom_headers.ptr, + ) + payload.check_error(err) def fetch( self, diff --git a/test/test_remote.py b/test/test_remote.py index fda4fe3e..5f4e59e7 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -30,7 +30,8 @@ import pytest import pygit2 -from pygit2 import Remote, Repository +from pygit2 import Remote, RemoteCallbacks, Repository +from pygit2.ffi import ffi from pygit2.remotes import PushUpdate, TransferProgress from . import utils @@ -485,8 +486,6 @@ def test_push_non_fast_forward_commits_to_remote_fails( def test_push_options(origin: Repository, clone: Repository, remote: Remote) -> None: - from pygit2 import RemoteCallbacks - callbacks = RemoteCallbacks() remote.push(['refs/heads/master'], callbacks) remote_push_options = callbacks.push_options.remote_push_options @@ -515,8 +514,6 @@ def test_push_options(origin: Repository, clone: Repository, remote: Remote) -> def test_push_threads(origin: Repository, clone: Repository, remote: Remote) -> None: - from pygit2 import RemoteCallbacks - callbacks = RemoteCallbacks() remote.push(['refs/heads/master'], callbacks) assert callbacks.push_options.pb_parallelism == 1 @@ -562,3 +559,142 @@ def push_negotiation(self, updates: list[PushUpdate]) -> None: assert the_updates[0].dst == new_tip_id assert origin.branches['master'].target == new_tip_id + + +class HeaderCallbacks(RemoteCallbacks): + def custom_headers(self) -> list[str] | None: + return ['X-Other-One: foo', 'X-Other-Two: bar'] + + +def test_git_custom_headers_context_manager( + origin: Repository, + clone: Repository, + remote: Remote, +) -> None: + from pygit2.callbacks import git_custom_headers, git_fetch_options, git_push_options + + class EmptyHeaderCallbacks(RemoteCallbacks): + def custom_headers(self) -> list[str] | None: + return [] + + callbacks = RemoteCallbacks() + with git_custom_headers(callbacks) as headers: + assert headers.ptr == ffi.NULL + + callbacks = EmptyHeaderCallbacks() + with git_custom_headers(callbacks) as headers: + assert headers.ptr == ffi.NULL + + callbacks = HeaderCallbacks() + with git_custom_headers(callbacks) as headers: + ptr = headers.ptr + assert ptr != ffi.NULL + assert ptr.count == 2 # type: ignore[union-attr] + assert ffi.string(ptr.strings[0]) == b'X-Other-One: foo' # type: ignore[union-attr,index] + assert ffi.string(ptr.strings[1]) == b'X-Other-Two: bar' # type: ignore[union-attr,index] + + callbacks = RemoteCallbacks() + with git_fetch_options(callbacks) as payload: + assert payload.fetch_options.custom_headers.count == 0 + assert payload.fetch_options.custom_headers.strings == ffi.NULL + + callbacks = EmptyHeaderCallbacks() + with git_fetch_options(callbacks) as payload: + assert payload.fetch_options.custom_headers.count == 0 + assert payload.fetch_options.custom_headers.strings == ffi.NULL + + callbacks = HeaderCallbacks() + with git_fetch_options(callbacks) as payload: + assert payload.fetch_options.custom_headers.count == 2 + assert ( + ffi.string(payload.fetch_options.custom_headers.strings[0]) + == b'X-Other-One: foo' + ) + assert ( + ffi.string(payload.fetch_options.custom_headers.strings[1]) + == b'X-Other-Two: bar' + ) + + callbacks = RemoteCallbacks() + with git_push_options(callbacks) as payload: + assert payload.push_options.custom_headers.count == 0 + assert payload.push_options.custom_headers.strings == ffi.NULL + + callbacks = EmptyHeaderCallbacks() + with git_push_options(callbacks) as payload: + assert payload.push_options.custom_headers.count == 0 + assert payload.push_options.custom_headers.strings == ffi.NULL + + callbacks = HeaderCallbacks() + with git_push_options(callbacks) as payload: + assert payload.push_options.custom_headers.count == 2 + assert ( + ffi.string(payload.push_options.custom_headers.strings[0]) + == b'X-Other-One: foo' + ) + assert ( + ffi.string(payload.push_options.custom_headers.strings[1]) + == b'X-Other-Two: bar' + ) + + +def test_push_headers(origin: Repository, clone: Repository, remote: Remote) -> None: + callbacks = RemoteCallbacks() + remote.push(['refs/heads/master'], callbacks=callbacks) + assert callbacks.push_options.custom_headers.count == 0 + assert callbacks.push_options.custom_headers.strings == ffi.NULL + + callbacks = HeaderCallbacks() + remote.push(['refs/heads/master'], callbacks=callbacks) + assert callbacks.push_options.custom_headers.count == 2 + assert callbacks.push_options.custom_headers.strings != ffi.NULL + # strings pointed to by callbacks.push_options.custom_headers.strings[] are already freed + + # make sure the custom headers don't "stick around" + callbacks = RemoteCallbacks() + remote.push(['refs/heads/master'], callbacks=callbacks) + assert callbacks.push_options.custom_headers.count == 0 + assert callbacks.push_options.custom_headers.strings == ffi.NULL + + +def test_fetch_headers(origin: Repository, clone: Repository, remote: Remote) -> None: + callbacks = RemoteCallbacks() + remote.fetch(['refs/heads/master'], callbacks=callbacks) + assert callbacks.fetch_options.custom_headers.count == 0 + assert callbacks.fetch_options.custom_headers.strings == ffi.NULL + + callbacks = HeaderCallbacks() + remote.fetch(['refs/heads/master'], callbacks=callbacks) + assert callbacks.fetch_options.custom_headers.count == 2 + assert callbacks.fetch_options.custom_headers.strings != ffi.NULL + # strings pointed to by callbacks.fetch_options.custom_headers.strings[] are already freed + + # make sure the custom headers don't "stick around" + callbacks = RemoteCallbacks() + remote.fetch(['refs/heads/master'], callbacks=callbacks) + assert callbacks.fetch_options.custom_headers.count == 0 + assert callbacks.fetch_options.custom_headers.strings == ffi.NULL + + +@utils.requires_network +def test_connect_headers(testrepo: Repository) -> None: + # This is just a check that having custom headers doesn't cause errors. As far as I can tell, + # there's no way to assert that C.git_remote_connect was called with the headers except for + # having a remote server that expects the headers and fails without them. + + assert 1 == len(testrepo.remotes) + remote = testrepo.remotes[0] + + callbacks = RemoteCallbacks() + remote.connect(callbacks=callbacks) + refs = remote.list_heads(connect=False) + assert refs + # Check that a known ref is returned. + assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2')) + + callbacks = HeaderCallbacks() + remote.connect(callbacks=callbacks) + refs = remote.list_heads(connect=False) + assert refs + # Check that a known ref is returned. + assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2'))