Skip to content

Commit 43a1b6d

Browse files
authored
0.0.16
- add file upload - apidocs generation fix for form
2 parents 050458c + 73dda47 commit 43a1b6d

File tree

12 files changed

+151
-51
lines changed

12 files changed

+151
-51
lines changed

README.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ from marshmallow import Schema, fields, ValidationError, post_load
2222
from starlette.applications import Starlette
2323
from starlette.datastructures import UploadFile
2424
from starlette.responses import JSONResponse
25+
from apispec.ext.marshmallow import MarshmallowPlugin
26+
from apispec import APISpec
2527

2628
from dataclasses import dataclass
29+
from datetime import datetime
2730

28-
from star_resty import Method, Operation, endpoint, json_schema, json_payload, form_payload, query, setup_spec
31+
from star_resty import Method, Operation, endpoint, json_schema, json_payload, upload, query, setup_spec, form_payload
2932
from typing import Optional
3033

3134
class EchoInput(Schema):
@@ -37,6 +40,7 @@ class JsonPayloadSchema(Schema):
3740
a = fields.Int(required=True)
3841
s = fields.String()
3942

43+
ma_plugin = MarshmallowPlugin()
4044

4145
# Json Payload (by dataclass)
4246
@dataclass
@@ -54,15 +58,9 @@ class JsonPayloadDataclass(Schema):
5458

5559

5660
# Form Payload
57-
class FormFile(fields.Field):
58-
def _validate(self, value):
59-
if not isinstance(value, UploadFile):
60-
raise ValidationError('Not a file')
61-
62-
6361
class FormPayload(Schema):
6462
id = fields.Int(required=True)
65-
file = FormFile()
63+
file_dt = fields.DateTime()
6664

6765

6866
app = Starlette(debug=True)
@@ -105,10 +103,15 @@ class PostDataclass(Method):
105103
class PostForm(Method):
106104
meta = Operation(tag='default', description='post form')
107105

108-
async def execute(self, form_data: form_payload(FormPayload)):
109-
file_name = form_data.get('file').filename
106+
async def execute(self, form_data: form_payload(FormPayload),
107+
files_reqired: upload('selfie', 'doc', required=True),
108+
files_optional: upload('file1', 'file2', 'file3')):
109+
files = {}
110+
for file in files_reqired + files_optional:
111+
body = await file.read()
112+
files[file.filename] = f"{body.hex()[:10]}..."
110113
id = form_data.get('id')
111-
return {'message': f"file {file_name} with id {id} received"}
114+
return {'message': f"files received (id: {id})", "files": files}
112115

113116

114117
if __name__ == '__main__':
@@ -118,4 +121,4 @@ if __name__ == '__main__':
118121
uvicorn.run(app, port=8080)
119122
```
120123

121-
Open [http://localhost:8080/apidocs.json](http://localhost:8080/apidocs.json) to view generated openapi schema.
124+
Open [http://localhost:8080/apidocs.json](http://localhost:8080/apidocs.json) to view generated openapi schema.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_packages(package):
3030
'apispec<4',
3131
'python-multipart'
3232
],
33-
version='0.0.15',
33+
version='0.0.16',
3434
url='https://github.com/slv0/start_resty',
3535
license='BSD',
3636
description='The web framework',

star_resty/apidocs/request.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def resolve_parameters(endpoint: Method):
1111
return parameters
1212

1313
for p in parser:
14-
if p.schema is not None and p.location != 'body':
15-
parameters.append({'in': p.location, 'schema': p.schema})
14+
if not p.is_body:
15+
parameters.extend(p.get_spec())
1616

1717
return parameters
1818

@@ -30,8 +30,8 @@ def resolve_request_body(endpoint: Method):
3030
def resolve_request_body_content(parser: RequestParser):
3131
content = {}
3232
for p in parser:
33-
if p.schema is not None and p.location == 'body' and p.media_type:
34-
content[p.media_type] = {'schema': p.schema}
33+
if p.is_body:
34+
content.update(p.get_body_spec())
3535

3636
return content
3737

@@ -43,11 +43,7 @@ def resolve_request_body_params(endpoint: Method):
4343
return params
4444

4545
for p in parser:
46-
if p.schema is not None and p.location == 'body' and p.media_type:
47-
params.append({
48-
'name': 'body',
49-
'in': 'body',
50-
'schema': p.schema
51-
})
46+
if p.is_body:
47+
params.extend(p.get_spec())
5248

5349
return params

star_resty/method/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from starlette.requests import Request
88

9-
from star_resty.payload.parser import Parser
9+
from star_resty.payload.base import Parser
1010

1111
__all__ = ('RequestParser', 'create_parser')
1212

star_resty/payload/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from .form import form_payload, form_schema
12
from .header import header, header_schema
23
from .json import json_payload, json_schema
34
from .path import path, path_schema
45
from .query import query, query_schema
5-
from .form import form_payload, form_schema
6+
from .upload import upload

star_resty/payload/parser.py renamed to star_resty/payload/base.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,44 @@
11
import abc
22
import inspect
3-
from typing import Dict, Optional, Type, Union
3+
from functools import lru_cache
4+
from typing import Dict, Optional, Type, Union, Iterable, Mapping, Tuple
45

56
from marshmallow import EXCLUDE, Schema
67
from starlette.requests import Request
7-
from functools import lru_cache
88

9-
__all__ = ('Parser', 'set_parser')
9+
__all__ = ('Parser', 'SchemaParser', 'set_parser')
1010

1111

1212
class Parser(abc.ABC):
13+
__slots__ = ()
14+
15+
@abc.abstractmethod
16+
def parse(self, request: Request):
17+
raise NotImplementedError
18+
19+
@staticmethod
20+
def get_spec() -> Iterable[Mapping]:
21+
return ()
22+
23+
@staticmethod
24+
def get_body_spec() -> Iterable[Tuple[str, Mapping]]:
25+
return ()
26+
27+
@property
28+
def location(self) -> Optional[str]:
29+
return None
30+
31+
@property
32+
def media_type(self) -> Optional[str]:
33+
return None
34+
35+
@property
36+
def is_body(self) -> bool:
37+
return self.location == 'body'
38+
39+
40+
41+
class SchemaParser(Parser, metaclass=abc.ABCMeta):
1342
__slots__ = ('schema', 'unknown')
1443

1544
@classmethod
@@ -33,17 +62,12 @@ def __init__(self, schema: Schema, unknown=EXCLUDE):
3362
self.schema = schema
3463
self.unknown = unknown
3564

36-
@abc.abstractmethod
37-
def parse(self, request: Request):
38-
pass
65+
def get_spec(self):
66+
yield {'in': self.location, 'schema': self.schema}
3967

40-
@property
41-
def location(self) -> Optional[str]:
42-
return None
43-
44-
@property
45-
def media_type(self) -> Optional[str]:
46-
return None
68+
def get_body_spec(self):
69+
if self.media_type:
70+
yield self.media_type, {'schema': self.schema}
4771

4872

4973
def set_parser(parser: Parser):

star_resty/payload/form.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from starlette.requests import Request
66

77
from star_resty.exceptions import DecodeError
8-
from .parser import Parser, set_parser
8+
from .base import SchemaParser, set_parser
99

1010
__all__ = ('form_schema', 'form_payload', 'FormParser')
1111

@@ -22,12 +22,12 @@ def form_payload(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[M
2222
return form_schema(schema, Mapping, unknown=unknown)
2323

2424

25-
class FormParser(Parser):
25+
class FormParser(SchemaParser):
2626
__slots__ = ()
2727

2828
@property
2929
def location(self):
30-
return 'body'
30+
return 'formData'
3131

3232
@property
3333
def media_type(self):
@@ -36,7 +36,7 @@ def media_type(self):
3636
async def parse(self, request: Request):
3737
try:
3838
form_data = await request.form()
39-
form_data = {} if not form_data else form_data
4039
except Exception as e:
41-
raise DecodeError('Invalid form data: %s' % (str(e))) from e
40+
raise DecodeError('Invalid form data: %s' % (str(e))) from e
41+
4242
return self.schema.load(form_data, unknown=self.unknown)

star_resty/payload/header.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from marshmallow import EXCLUDE, Schema
55
from starlette.requests import Request
66

7-
from .parser import Parser, set_parser
7+
from .base import SchemaParser, set_parser
88

99
__all__ = ('header', 'header_schema', 'HeaderParser')
1010

@@ -21,7 +21,7 @@ def header(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping
2121
return header_schema(schema, Mapping, unknown=unknown)
2222

2323

24-
class HeaderParser(Parser):
24+
class HeaderParser(SchemaParser):
2525
__slots__ = ()
2626

2727
@property

star_resty/payload/json.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from starlette.requests import Request
77

88
from star_resty.exceptions import DecodeError
9-
from .parser import Parser, set_parser
9+
from .base import SchemaParser, set_parser
1010

1111
__all__ = ('json_schema', 'json_payload', 'JsonParser')
1212

@@ -23,7 +23,7 @@ def json_payload(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[M
2323
return json_schema(schema, Mapping, unknown=unknown)
2424

2525

26-
class JsonParser(Parser):
26+
class JsonParser(SchemaParser):
2727
__slots__ = ()
2828

2929
@property

star_resty/payload/path.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from marshmallow import EXCLUDE, Schema
55
from starlette.requests import Request
66

7-
from .parser import Parser, set_parser
7+
from .base import SchemaParser, set_parser
88

99
__all__ = ('path', 'path_schema', 'PathParser')
1010

@@ -21,7 +21,7 @@ def path(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]:
2121
return path_schema(schema, Mapping, unknown=unknown)
2222

2323

24-
class PathParser(Parser):
24+
class PathParser(SchemaParser):
2525
__slots__ = ()
2626

2727
@property

star_resty/payload/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from marshmallow import EXCLUDE, Schema, fields
77
from starlette.requests import Request
88

9-
from .parser import Parser, set_parser
9+
from .base import SchemaParser, set_parser
1010

1111
__all__ = ('query', 'query_schema', 'QueryParser')
1212

@@ -23,7 +23,7 @@ def query(schema: Union[Schema, Type[Schema]], unknown=EXCLUDE) -> Type[Mapping]
2323
return query_schema(schema, Mapping, unknown=unknown)
2424

2525

26-
class QueryParser(Parser):
26+
class QueryParser(SchemaParser):
2727
__slots__ = ('fields',)
2828

2929
@classmethod

star_resty/payload/upload.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import abc
2+
from typing import Optional, Any, Sequence, Mapping, Type
3+
4+
from marshmallow import ValidationError
5+
from starlette.datastructures import UploadFile
6+
from starlette.requests import Request
7+
8+
from .base import Parser
9+
10+
__all__ = ('upload',)
11+
12+
13+
class UploadSequence(Sequence[UploadFile], metaclass=abc.ABCMeta):
14+
pass
15+
16+
17+
def upload(*args: str,
18+
description: Optional[str] = None,
19+
required: bool = False) -> Type[UploadSequence]:
20+
def helper() -> Any:
21+
return UploadParser(args, description=description, required=required)
22+
23+
return helper()
24+
25+
26+
class UploadParser(Parser):
27+
28+
def __init__(self, file_names: Sequence[str] = (), *,
29+
description: Optional[str] = None,
30+
required: bool = False):
31+
self.files_names = frozenset(file_names)
32+
self.description = description
33+
self.required = required
34+
35+
@property
36+
def parser(self):
37+
return self
38+
39+
@property
40+
def media_type(self):
41+
return 'multipart/form-data'
42+
43+
@property
44+
def location(self):
45+
return 'formData'
46+
47+
async def parse(self, request: Request):
48+
form = await request.form()
49+
res = []
50+
for key, val in form.items():
51+
if not isinstance(val, UploadFile):
52+
continue
53+
54+
if not self.files_names or key in self.files_names:
55+
res.append(val)
56+
57+
if self.required and not res:
58+
raise ValidationError(message='Missing required file', field_name='form')
59+
60+
return res
61+
62+
def get_spec(self):
63+
if self.files_names:
64+
for name in sorted(self.files_names):
65+
yield self._create_spec(name)
66+
else:
67+
yield self._create_spec('upfile')
68+
69+
def _create_spec(self, name: str) -> Mapping:
70+
return {
71+
'in': 'formData',
72+
'type': 'file',
73+
'description': self.description or '',
74+
'name': name,
75+
'required': self.required
76+
}

0 commit comments

Comments
 (0)