Skip to content

Commit

Permalink
feat: add validate function to documents to have less verbose api (#1058)
Browse files Browse the repository at this point in the history

* feat: add validate on Text document

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: complete test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: add image shortcut

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: add video

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* docs: we to u

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: add mesh

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: add audio

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
  • Loading branch information
samsja committed Jan 27, 2023
1 parent 922f182 commit 4311bcc
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 7 deletions.
26 changes: 25 additions & 1 deletion docarray/documents/audio.py
@@ -1,9 +1,19 @@
from typing import Optional, TypeVar
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AudioUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor

try:
import torch

torch_available = True
except ImportError:
torch_available = False

T = TypeVar('T', bound='Audio')


Expand Down Expand Up @@ -76,3 +86,17 @@ class MultiModalDoc(Document):
url: Optional[AudioUrl]
tensor: Optional[AudioTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """Wrap a shorthand value into a document, then run the parent validation.

    A ``str`` is wrapped as ``cls(url=value)``. An ``AbstractTensor``, a
    ``np.ndarray`` or (only when torch could be imported) a ``torch.Tensor``
    is wrapped as ``cls(tensor=value)``. Any other value is forwarded as is.

    :param value: a url string, a tensor, or any other input
    :return: the result of the parent class's ``validate``
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    looks_like_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    # torch is optional: only touch ``torch.Tensor`` when the import succeeded.
    if not looks_like_tensor and torch_available:
        looks_like_tensor = isinstance(value, torch.Tensor)

    if looks_like_tensor:
        value = cls(tensor=value)
    return super().validate(value)
28 changes: 27 additions & 1 deletion docarray/documents/image.py
@@ -1,7 +1,19 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='Image')

try:
import torch

torch_available = True
except ImportError:
torch_available = False


class Image(BaseDocument):
Expand Down Expand Up @@ -67,3 +79,17 @@ class MultiModalDoc(BaseDocument):
url: Optional[ImageUrl]
tensor: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """Accept a bare url string or tensor in place of a full document.

    Strings become ``cls(url=value)``; instances of ``AbstractTensor``,
    ``np.ndarray`` or (if torch is importable) ``torch.Tensor`` become
    ``cls(tensor=value)``. Everything else is passed through unchanged.

    :param value: a url string, a tensor, or any other input
    :return: the result of the parent class's ``validate``
    """
    if isinstance(value, str):
        wrapped = cls(url=value)
        return super().validate(wrapped)

    is_tensor = isinstance(value, (AbstractTensor, np.ndarray))
    # Only reference ``torch`` when the optional import succeeded.
    if not is_tensor and torch_available:
        is_tensor = isinstance(value, torch.Tensor)

    candidate = cls(tensor=value) if is_tensor else value
    return super().validate(candidate)
13 changes: 12 additions & 1 deletion docarray/documents/mesh.py
@@ -1,8 +1,10 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, Mesh3DUrl

T = TypeVar('T', bound='Mesh3D')


class Mesh3D(BaseDocument):
"""
Expand Down Expand Up @@ -77,3 +79,12 @@ class MultiModalDoc(BaseDocument):
vertices: Optional[AnyTensor]
faces: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, Any],
) -> T:
    """Accept a bare url string in place of a full document.

    A ``str`` is wrapped as ``cls(url=value)``; any other value is passed
    through unchanged.

    :param value: a url string or any other input
    :return: the result of the parent class's ``validate``
    """
    candidate = cls(url=value) if isinstance(value, str) else value
    return super().validate(candidate)
28 changes: 27 additions & 1 deletion docarray/documents/point_cloud.py
@@ -1,7 +1,19 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.typing import AnyEmbedding, AnyTensor, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor

try:
import torch

torch_available = True
except ImportError:
torch_available = False

T = TypeVar('T', bound='PointCloud3D')


class PointCloud3D(BaseDocument):
Expand Down Expand Up @@ -75,3 +87,17 @@ class MultiModalDoc(BaseDocument):
url: Optional[PointCloud3DUrl]
tensor: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """Turn a url string or a tensor into a document before validating it.

    ``str`` input is wrapped as ``cls(url=value)``. ``AbstractTensor``,
    ``np.ndarray`` and (when torch is importable) ``torch.Tensor`` input is
    wrapped as ``cls(tensor=value)``. Other input is forwarded untouched.

    :param value: a url string, a tensor, or any other input
    :return: the result of the parent class's ``validate``
    """
    if isinstance(value, str):
        return super().validate(cls(url=value))

    tensor_like = isinstance(value, (AbstractTensor, np.ndarray))
    # The torch check is skipped entirely when torch is not installed.
    if not tensor_like and torch_available:
        tensor_like = isinstance(value, torch.Tensor)

    if tensor_like:
        return super().validate(cls(tensor=value))
    return super().validate(value)
13 changes: 12 additions & 1 deletion docarray/documents/text.py
@@ -1,9 +1,11 @@
from typing import Optional
from typing import Any, Optional, Type, TypeVar, Union

from docarray.base_document import BaseDocument
from docarray.typing import TextUrl
from docarray.typing.tensor.embedding import AnyEmbedding

T = TypeVar('T', bound='Text')


class Text(BaseDocument):
"""
Expand Down Expand Up @@ -68,3 +70,12 @@ class MultiModalDoc(BaseDocument):
text: Optional[str] = None
url: Optional[TextUrl] = None
embedding: Optional[AnyEmbedding] = None

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, Any],
) -> T:
    """Accept a bare string in place of a full document.

    A ``str`` is wrapped as ``cls(text=value)`` (note: ``text``, not ``url``);
    any other value is passed through unchanged.

    :param value: a plain string or any other input
    :return: the result of the parent class's ``validate``
    """
    if not isinstance(value, str):
        return super().validate(value)
    return super().validate(cls(text=value))
26 changes: 25 additions & 1 deletion docarray/documents/video.py
@@ -1,11 +1,21 @@
from typing import Optional, TypeVar
from typing import Any, Optional, Type, TypeVar, Union

import numpy as np

from docarray.base_document import BaseDocument
from docarray.documents import Audio
from docarray.typing import AnyEmbedding, AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl

try:
import torch

torch_available = True
except ImportError:
torch_available = False

T = TypeVar('T', bound='Video')


Expand Down Expand Up @@ -83,3 +93,17 @@ class MultiModalDoc(BaseDocument):
tensor: Optional[VideoTensor]
key_frame_indices: Optional[AnyTensor]
embedding: Optional[AnyEmbedding]

@classmethod
def validate(
    cls: Type[T],
    value: Union[str, AbstractTensor, Any],
) -> T:
    """Wrap a url string or a tensor into a document, then validate it.

    A ``str`` yields ``cls(url=value)``. An ``AbstractTensor``, a
    ``np.ndarray`` or (when torch is importable) a ``torch.Tensor`` yields
    ``cls(tensor=value)``. Any other value is handed over unchanged.

    :param value: a url string, a tensor, or any other input
    :return: the result of the parent class's ``validate``
    """
    if isinstance(value, str):
        value = cls(url=value)
    else:
        holds_tensor = isinstance(value, (AbstractTensor, np.ndarray))
        # Guard the optional dependency before naming ``torch.Tensor``.
        if not holds_tensor and torch_available:
            holds_tensor = isinstance(value, torch.Tensor)
        if holds_tensor:
            value = cls(tensor=value)

    return super().validate(value)
27 changes: 27 additions & 0 deletions tests/integrations/predefined_document/test_audio.py
Expand Up @@ -6,6 +6,7 @@
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Audio
from docarray.typing import AudioUrl
from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor
Expand Down Expand Up @@ -83,3 +84,29 @@ class MyAudio(Audio):

assert isinstance(my_audio.tensor, AudioNdArray)
assert isinstance(my_audio.url, AudioUrl)


def test_audio_np():
    """A NumPy array parsed as ``Audio`` is stored in the ``tensor`` field."""
    expected = np.zeros((10, 10, 3))
    parsed = parse_obj_as(Audio, np.zeros((10, 10, 3)))
    assert (parsed.tensor == expected).all()


def test_audio_torch():
    """A torch tensor parsed as ``Audio`` is stored in the ``tensor`` field."""
    expected = torch.zeros(10, 10, 3)
    parsed = parse_obj_as(Audio, torch.zeros(10, 10, 3))
    assert (parsed.tensor == expected).all()


def test_audio_shortcut_doc():
    """``Audio`` fields of a document accept a url, ndarray or torch tensor."""

    class MyDoc(BaseDocument):
        audio: Audio
        audio2: Audio
        audio3: Audio

    shorthand = {
        'audio': 'http://myurl.wav',
        'audio2': np.zeros((10, 10, 3)),
        'audio3': torch.zeros(10, 10, 3),
    }
    doc = MyDoc(**shorthand)

    assert doc.audio.url == 'http://myurl.wav'
    assert (doc.audio2.tensor == np.zeros((10, 10, 3))).all()
    assert (doc.audio3.tensor == torch.zeros(10, 10, 3)).all()
35 changes: 34 additions & 1 deletion tests/integrations/predefined_document/test_image.py
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Image

REMOTE_JPG = (
Expand All @@ -12,9 +15,39 @@
@pytest.mark.slow
@pytest.mark.internet
def test_image():
    """Loading ``REMOTE_JPG`` through ``Image.url`` gives a NumPy array."""
    img = Image(url=REMOTE_JPG)
    img.tensor = img.url.load()
    assert isinstance(img.tensor, np.ndarray)


def test_image_str():
    """A plain string parsed as ``Image`` ends up in the ``url`` field."""
    url = 'http://myurl.jpg'
    parsed = parse_obj_as(Image, url)
    assert parsed.url == url


def test_image_np():
    """A NumPy array parsed as ``Image`` is stored in the ``tensor`` field."""
    expected = np.zeros((10, 10, 3))
    parsed = parse_obj_as(Image, np.zeros((10, 10, 3)))
    assert (parsed.tensor == expected).all()


def test_image_torch():
    """A torch tensor parsed as ``Image`` is stored in the ``tensor`` field."""
    expected = torch.zeros(10, 10, 3)
    parsed = parse_obj_as(Image, torch.zeros(10, 10, 3))
    assert (parsed.tensor == expected).all()


def test_image_shortcut_doc():
    """``Image`` fields of a document accept a url, ndarray or torch tensor."""

    class MyDoc(BaseDocument):
        image: Image
        image2: Image
        image3: Image

    shorthand = {
        'image': 'http://myurl.jpg',
        'image2': np.zeros((10, 10, 3)),
        'image3': torch.zeros(10, 10, 3),
    }
    doc = MyDoc(**shorthand)

    assert doc.image.url == 'http://myurl.jpg'
    assert (doc.image2.tensor == np.zeros((10, 10, 3))).all()
    assert (doc.image3.tensor == torch.zeros(10, 10, 3)).all()
18 changes: 18 additions & 0 deletions tests/integrations/predefined_document/test_mesh.py
@@ -1,6 +1,8 @@
import numpy as np
import pytest
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Mesh3D
from tests import TOYDATA_DIR

Expand All @@ -19,3 +21,19 @@ def test_mesh(file_url):

assert isinstance(mesh.vertices, np.ndarray)
assert isinstance(mesh.faces, np.ndarray)


def test_str_init():
    """A plain string parsed as ``Mesh3D`` ends up in the ``url`` field."""
    url = 'http://hello.ply'
    mesh = parse_obj_as(Mesh3D, url)
    assert mesh.url == url


def test_doc():
    """``Mesh3D`` fields accept both a url string and a ``Mesh3D`` instance."""

    class MyDoc(BaseDocument):
        mesh1: Mesh3D
        mesh2: Mesh3D

    url = 'http://hello.ply'
    doc = MyDoc(mesh1=url, mesh2=Mesh3D(url=url))

    for mesh in (doc.mesh1, doc.mesh2):
        assert mesh.url == url
29 changes: 29 additions & 0 deletions tests/integrations/predefined_document/test_point_cloud.py
@@ -1,6 +1,9 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import PointCloud3D
from tests import TOYDATA_DIR

Expand All @@ -18,3 +21,29 @@ def test_point_cloud(file_url):
point_cloud.tensor = point_cloud.url.load(samples=100)

assert isinstance(point_cloud.tensor, np.ndarray)


def test_point_cloud_np():
    """A NumPy array parsed as ``PointCloud3D`` is stored in ``tensor``."""
    expected = np.zeros((10, 10, 3))
    point_cloud = parse_obj_as(PointCloud3D, np.zeros((10, 10, 3)))
    assert (point_cloud.tensor == expected).all()


def test_point_cloud_torch():
    """A torch tensor parsed as ``PointCloud3D`` is stored in ``tensor``."""
    expected = torch.zeros(10, 10, 3)
    point_cloud = parse_obj_as(PointCloud3D, torch.zeros(10, 10, 3))
    assert (point_cloud.tensor == expected).all()


def test_point_cloud_shortcut_doc():
    """``PointCloud3D`` fields accept a url, an ndarray or a torch tensor.

    The fields are named after what they hold (point clouds) rather than
    the ``image*`` names carried over from the image test.
    """

    class MyDoc(BaseDocument):
        point_cloud: PointCloud3D
        point_cloud2: PointCloud3D
        point_cloud3: PointCloud3D

    doc = MyDoc(
        point_cloud='http://myurl.ply',
        point_cloud2=np.zeros((10, 10, 3)),
        point_cloud3=torch.zeros(10, 10, 3),
    )
    assert doc.point_cloud.url == 'http://myurl.ply'
    assert (doc.point_cloud2.tensor == np.zeros((10, 10, 3))).all()
    assert (doc.point_cloud3.tensor == torch.zeros(10, 10, 3)).all()
25 changes: 25 additions & 0 deletions tests/integrations/predefined_document/test_text.py
@@ -0,0 +1,25 @@
from pydantic import parse_obj_as

from docarray import BaseDocument
from docarray.documents import Text


def test_simple_init():
    """``Text`` built with the ``text`` keyword exposes that same string."""
    content = 'hello'
    doc = Text(text=content)
    assert doc.text == content


def test_str_init():
    """A plain string parsed as ``Text`` ends up in the ``text`` field."""
    content = 'hello'
    doc = parse_obj_as(Text, content)
    assert doc.text == content


def test_doc():
    """``Text`` fields accept both a plain string and a ``Text`` instance."""

    class MyDoc(BaseDocument):
        text1: Text
        text2: Text

    prebuilt = Text(text='world')
    doc = MyDoc(text1='hello', text2=prebuilt)

    observed = (doc.text1.text, doc.text2.text)
    assert observed[0] == 'hello'
    assert observed[1] == 'world'

0 comments on commit 4311bcc

Please sign in to comment.