Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion pygit2/_libgit2/ffi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,36 @@ class GitRepositoryC:
# def _from_c(cls, ptr: 'GitRepositoryC', owned: bool) -> 'Repository': ...
pass

class GitRemoteCallbacksC:
# TODO: Several Anys need filling in
version: int
sideband_progress: Any
completion: Any
credentials: Any
certificate_check: Any
transfer_progress: Any
update_tips: Any
pack_progress: Any
push_transfer_progress: Any
push_update_reference: Any
push_negotiation: Any
transport: Any
remote_ready: Any
payload: Any
resolve_url: Any
update_refs: Any

class GitFetchOptionsC:
# TODO: FetchOptions exist in _pygit2.pyi
# incomplete
depth: int
callbacks: GitRemoteCallbacksC
custom_headers: GitStrrayC

class GitPushOptionsC:
# TODO incomplete
callbacks: GitRemoteCallbacksC
custom_headers: GitStrrayC

class GitSubmoduleC:
pass
Expand Down Expand Up @@ -225,7 +251,11 @@ class GitRepositoryInitOptionsC:
origin_url: ArrayC[char]

class GitCloneOptionsC:
pass
# TODO: Several Anys need filling in
repository_cb: Any
repository_cb_payload: Any
remote_cb: Any
remote_cb_payload: Any

class GitPackbuilderC:
pass
Expand Down Expand Up @@ -256,6 +286,8 @@ def new(a: Literal['git_repository **']) -> _Pointer[GitRepositoryC]: ...
@overload
def new(a: Literal['git_remote **']) -> _Pointer[GitRemoteC]: ...
@overload
def new(a: Literal['git_remote_callbacks *']) -> GitRemoteCallbacksC: ...
@overload
def new(a: Literal['git_transaction **']) -> _Pointer[GitTransactionC]: ...
@overload
def new(a: Literal['git_repository_init_options *']) -> GitRepositoryInitOptionsC: ...
Expand All @@ -276,8 +308,12 @@ def new(a: Literal['git_blob **']) -> _Pointer[GitBlobC]: ...
@overload
def new(a: Literal['git_clone_options *']) -> GitCloneOptionsC: ...
@overload
def new(a: Literal['git_fetch_options *']) -> GitFetchOptionsC: ...
@overload
def new(a: Literal['git_merge_options *']) -> GitMergeOptionsC: ...
@overload
def new(a: Literal['git_push_options *']) -> GitPushOptionsC: ...
@overload
def new(a: Literal['git_blame_options *']) -> GitBlameOptionsC: ...
@overload
def new(a: Literal['git_annotated_commit **']) -> _Pointer[GitAnnotatedCommitC]: ...
Expand Down Expand Up @@ -364,6 +400,7 @@ def new(
a: Literal['char *[]'], b: list[Any]
) -> ArrayC[char_pointer]: ... # For string arrays
def addressof(a: object, attribute: str) -> _Pointer[object]: ...
def new_handle(a: T) -> _Pointer[T]: ...

class buffer(bytes):
def __init__(self, a: object) -> None: ...
Expand Down
92 changes: 69 additions & 23 deletions pygit2/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,15 @@
_Credentials = Username | UserPass | Keypair

if TYPE_CHECKING:
from pygit2._libgit2.ffi import GitProxyOptionsC
from pygit2._libgit2.ffi import (
GitCloneOptionsC,
GitFetchOptionsC,
GitProxyOptionsC,
GitPushOptionsC,
GitStrrayC,
)

from .remotes import PushUpdate, TransferProgress
from .remotes import PushUpdate, Remote, TransferProgress
#
# The payload is the way to pass information from the pygit2 API, through
# libgit2, to the Python callbacks. And back.
Expand All @@ -92,6 +98,9 @@ class Payload:
repository: Callable | None
remote: Callable | None
clone_options: Any
fetch_options: Any
push_options: Any
remote_callbacks: Any

def __init__(self, **kw: object) -> None:
for key, value in kw.items():
Expand Down Expand Up @@ -120,12 +129,10 @@ class RemoteCallbacks(Payload):
method, or if it's a constant value, pass the value to the constructor,
e.g. RemoteCallbacks(credentials=credentials).

You can as well pass the certificate the same way, for example:
RemoteCallbacks(certificate=certificate).
You can as well pass the certificate check callback the same way, for example:
RemoteCallbacks(certificate_check=certificate_check).
"""

push_options: Any

def __init__(
self,
credentials: _Credentials | None = None,
Expand Down Expand Up @@ -262,6 +269,19 @@ def push_update_reference(self, refname: str, message: str) -> None:
Rejection message from the remote. If None, the update was accepted.
"""

def custom_headers(self) -> list[str] | None:
"""
Custom headers callback. Override with your own function to return a
list of custom headers that should be used when connecting to, pushing
to, or fetching from the remote.

Example use case to authenticate with bearer tokens instead of username/password:

return [f"Authorization: Bearer {token}"]

Returns: list of header strings or None
"""


class CheckoutCallbacks(Payload):
"""Base class for pygit2 checkout callbacks.
Expand Down Expand Up @@ -352,7 +372,22 @@ def stash_apply_progress(self, progress: StashApplyProgress) -> None:


@contextmanager
def git_clone_options(payload, opts=None):
def git_custom_headers(
payload: RemoteCallbacks,
opts_custom_headers: Optional['GitStrrayC'] = None,
) -> Generator[StrArray, Any, None]:
custom_headers = payload.custom_headers() or None
with StrArray(custom_headers) as headers_array:
if opts_custom_headers is not None:
headers_array.assign_to(opts_custom_headers)
yield headers_array

Comment thread
beamerblvd marked this conversation as resolved.

@contextmanager
def git_clone_options(
payload: RemoteCallbacks,
opts: Optional['GitCloneOptionsC'] = None,
) -> Generator[RemoteCallbacks, Any, None]:
if opts is None:
opts = ffi.new('git_clone_options *')
C.git_clone_options_init(opts, C.GIT_CLONE_OPTIONS_VERSION)
Expand All @@ -374,7 +409,10 @@ def git_clone_options(payload, opts=None):


@contextmanager
def git_fetch_options(payload, opts=None):
def git_fetch_options(
payload: RemoteCallbacks | None,
opts: Optional['GitFetchOptionsC'] = None,
) -> Generator[RemoteCallbacks, Any, None]:
if payload is None:
payload = RemoteCallbacks()

Expand All @@ -392,15 +430,16 @@ def git_fetch_options(payload, opts=None):
handle = ffi.new_handle(payload)
opts.callbacks.payload = handle

# Give back control
payload.fetch_options = opts
payload._stored_exception = None
yield payload
with git_custom_headers(payload, opts.custom_headers):
# Give back control
payload.fetch_options = opts
payload._stored_exception = None
yield payload


@contextmanager
def git_proxy_options(
payload: object,
payload: 'Remote | RemoteCallbacks',
opts: Optional['GitProxyOptionsC'] = None,
proxy: None | bool | str = None,
) -> Generator['GitProxyOptionsC', None, None]:
Expand All @@ -414,20 +453,24 @@ def git_proxy_options(
elif type(proxy) is str:
opts.type = C.GIT_PROXY_SPECIFIED
# Keep url in memory, otherwise memory is freed and bad things happen
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[attr-defined]
opts.url = payload.__proxy_url # type: ignore[attr-defined]
payload.__proxy_url = ffi.new('char[]', to_bytes(proxy)) # type: ignore[union-attr]
opts.url = payload.__proxy_url # type: ignore[union-attr]
else:
raise TypeError('Proxy must be None, True, or a string')
yield opts


@contextmanager
def git_push_options(payload, opts=None):
def git_push_options(
payload: RemoteCallbacks | None,
opts: Optional['GitPushOptionsC'] = None,
) -> Generator[RemoteCallbacks, Any, None]:
if payload is None:
payload = RemoteCallbacks()

opts = ffi.new('git_push_options *')
C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION)
if opts is None:
opts = ffi.new('git_push_options *')
C.git_push_options_init(opts, C.GIT_PUSH_OPTIONS_VERSION)

# Plug callbacks
opts.callbacks.sideband_progress = C._sideband_progress_cb
Expand All @@ -448,14 +491,17 @@ def git_push_options(payload, opts=None):
handle = ffi.new_handle(payload)
opts.callbacks.payload = handle

# Give back control
payload.push_options = opts
payload._stored_exception = None
yield payload
with git_custom_headers(payload, opts.custom_headers):
# Give back control
payload.push_options = opts
payload._stored_exception = None
yield payload


@contextmanager
def git_remote_callbacks(payload):
def git_remote_callbacks(
payload: RemoteCallbacks | None,
) -> Generator[RemoteCallbacks, Any, None]:
if payload is None:
payload = RemoteCallbacks()

Expand Down
18 changes: 10 additions & 8 deletions pygit2/remotes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import utils
from ._pygit2 import Oid
from .callbacks import (
git_custom_headers,
git_fetch_options,
git_proxy_options,
git_push_options,
Expand Down Expand Up @@ -189,14 +190,15 @@ def connect(
"""
with git_proxy_options(self, proxy=proxy) as proxy_opts:
with git_remote_callbacks(callbacks) as payload:
err = C.git_remote_connect(
self._remote,
direction,
payload.remote_callbacks,
proxy_opts,
ffi.NULL,
)
payload.check_error(err)
with git_custom_headers(payload) as custom_headers:
err = C.git_remote_connect(
self._remote,
direction,
payload.remote_callbacks,
proxy_opts,
custom_headers.ptr,
)
payload.check_error(err)

def fetch(
self,
Expand Down
Loading
Loading