import os
from django.conf import settings
from django.core.management import BaseCommand
from django_test_tools.generators.crud_generator import GenericTemplateWriter
from django_test_tools.generators.model_generator import FactoryBoyGenerator
from ...app_manager import DjangoAppManager
PRINT_IMPORTS = """
import string
from random import randint
from pytz import timezone
from django.conf import settings
from factory import Iterator
from factory import LazyAttribute
from factory import SubFactory
from factory import lazy_attribute
from factory.django import DjangoModelFactory, FileField
from factory.fuzzy import FuzzyText, FuzzyInteger
from faker import Factory as FakerFactory
faker = FakerFactory.create()
"""
PRINT_FACTORY_CLASS = """
class {0}Factory(DjangoModelFactory):
class Meta:
model = {0}
"""
PRINT_CHARFIELD = """ {} = LazyAttribute(lambda x: faker.text(max_nb_chars={}))"""
PRINT_CHARFIELD_NUM = """ {} = LazyAttribute(lambda x: FuzzyText(length={}, chars=string.digits).fuzz())"""
PRINT_CHARFIELD_LETTERS = """ {} = LazyAttribute(lambda x: FuzzyText(length={}, chars=string.ascii_letters).fuzz())"""
PRINT_CHARFIELD_CHOICES = """ {} = Iterator({}.{}, getter=lambda x: x[0])"""
PRINT_DATETIMEFIELD = """ {} = LazyAttribute(lambda x: faker.date_time_between(start_date="-1y", end_date="now",
tzinfo=timezone(settings.TIME_ZONE)))"""
PRINT_FOREIGNKEY = """ {} = SubFactory({}Factory){}"""
PRINT_FILEFIELD = """ {} = FileField(filename='{}.{}')"""
PRINT_BOOLEANFIELD = """ {} = Iterator([True, False])"""
PRINT_INTEGERFIELD = """ {} = LazyAttribute(lambda o: randint(1, 100))"""
PRINT_IPADDRESSFIELD = """ {} = LazyAttribute(lambda o: faker.ipv4(network=False))"""
PRINT_TEXTFIELD = """ {} = LazyAttribute(lambda x: faker.paragraph(nb_sentences=3, variable_nb_sentences=True))"""
PRINT_DECIMALFIELD = """ {} = LazyAttribute(lambda x: faker.pydecimal(left_digits={}, right_digits={}, positive=True))"""
PRINT_UNSUPPORTED_FIELD = """ #{} = {} We do not support this field type"""
PRINT_COUNTRYFIELD = """ {} = Iterator(['PA', 'US'])"""
# noinspection PyProtectedMember
[docs]class ModelFactoryGenerator(object):
def __init__(self, model):
self.model = model
def _generate(self):
factory_class_content = list()
factory_class_content.append({'print': PRINT_FACTORY_CLASS, 'args': [self.model.__name__]})
for field in self.model._meta.fields:
field_type = type(field).__name__
field_data = dict()
if field_type in ['AutoField', 'AutoCreatedField', 'AutoLastModifiedField']:
pass
elif field_type in ['DateTimeField', 'DateField']:
field_data = {'print': PRINT_DATETIMEFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
elif field_type == 'CharField':
field_data = self._get_charfield(field)
factory_class_content.append(field_data)
elif field_type == 'ForeignKey':
related_model = field.remote_field.model.__name__
field_data = {'print': PRINT_FOREIGNKEY,
'args': [field.name, related_model, '']}
factory_class_content.append(field_data)
elif field_type == 'BooleanField':
field_data = {'print': PRINT_BOOLEANFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
elif field_type == 'TextField':
field_data = {'print': PRINT_TEXTFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
elif field_type == 'IntegerField':
field_data = {'print': PRINT_INTEGERFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
elif field_type == 'FileField':
field_data = {'print': PRINT_FILEFIELD, 'args': [field.name, field.name, 'xlsx']}
factory_class_content.append(field_data)
elif field_type == 'DecimalField' or field_type == 'MoneyField':
max_left = field.max_digits - field.decimal_places
max_right = field.decimal_places
field_data = {'print': PRINT_DECIMALFIELD,
'args': [field.name, max_left, max_right]}
factory_class_content.append(field_data)
elif field_type == 'GenericIPAddressField':
field_data = {'print': PRINT_IPADDRESSFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
elif field_type == 'CountryField':
field_data = {'print': PRINT_COUNTRYFIELD, 'args': [field.name]}
factory_class_content.append(field_data)
else:
field_data = {'print': PRINT_UNSUPPORTED_FIELD,
'args': [field.name, field_type]}
factory_class_content.append(field_data)
return factory_class_content
def _get_charfield(self, field):
field_data = dict()
if field.choices is not None and len(field.choices) > 0:
field_data = {'print': PRINT_CHARFIELD_CHOICES, 'args': [field.name, self.model.__name__, 'CHOICES']}
return field_data
else:
if self._is_number(field.name):
field_data = {'print': PRINT_CHARFIELD_NUM,
'args': [field.name, field.max_length]}
return field_data
else:
if field.max_length >= 5:
field_data = {'print': PRINT_CHARFIELD,
'args': [field.name, field.max_length]}
else:
field_data = {'print': PRINT_CHARFIELD_LETTERS,
'args': [field.name, field.max_length]}
return field_data
def _is_number(self, field_name):
num_vals = ['id', 'num']
for nv in num_vals:
if nv in field_name.lower():
return True
return False
def __str__(self):
printable = list()
for print_data in self._generate():
try:
printable.append(print_data['print'].format(*print_data['args']))
except IndexError as e:
print('-' * 74)
print('{print} {args}'.format(**print_data))
raise e
return '\n'.join(printable)
[docs]class Command(BaseCommand):
"""
$ python manage.py generate_factories project.app
"""
[docs] def add_arguments(self, parser):
parser.add_argument('app_name')
# parser.add_argument("-l", "--list",
# action='store_true',
# dest="list",
# help="List employees",
# )
# parser.add_argument("-a", "--assign",
# action='store_true',
# dest="assign",
# help="Create unit assignments",
# )
#
#
parser.add_argument("--filename",
dest="filename",
help="Output filename",
default=None,
)
# parser.add_argument("--start-date",
# dest="start_date",
# help="Start date for the assignment",
# default=None,
# )
# parser.add_argument("--fiscal-year",
# dest="fiscal_year",
# help="Fiscal year for assignments",
# default=None,
# )
# parser.add_argument("-u", "--username",
# dest="usernames",
# help="LDAP usernames for employees",
# nargs='+',
# )
[docs] def handle(self, *args, **options):
app_name = options.get('app_name')
if options.get('filename'):
filename = os.path.join(settings.TEST_OUTPUT_PATH, options.get('filename'))
generator = FactoryBoyGenerator()
factory_data = generator.create_template_data(app_name)
template_name = 'factories.py.j2'
writer = GenericTemplateWriter(template_name)
writer.write(factory_data, filename)
else:
app_manager = DjangoAppManager()
app = app_manager.get_app(app_name)
if not app:
self.stderr.write('This command requires an existing app name as '
'argument')
self.stderr.write('Available apps:')
for app in sorted(app_manager.installed_apps):
self.stderr.write(' %s' % app)
else:
self.stdout.write(PRINT_IMPORTS)
for model in app.get_models():
model_fact = ModelFactoryGenerator(model)
self.stdout.write(str(model_fact))