Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pip install sqlalchemy_adapter

## Simple Example

You can save and load policy to database.

```python
import sqlalchemy_adapter
import casbin
Expand All @@ -49,6 +51,22 @@ else:
pass
```

By default, policies are stored in the `casbin_rule` table.
You can custom the table where the policy is stored by using the `table_name` parameter.

```python

import sqlalchemy_adapter
import casbin

custom_table_name = "<custom_table_name>"

# create adapter with custom table name.
adapter = sqlalchemy_adapter.Adapter('sqlite:///test.db', table_name=custom_table_name)

e = casbin.Enforcer('path/to/model.conf', adapter)
```


### Getting Help

Expand Down
71 changes: 47 additions & 24 deletions sqlalchemy_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,46 @@ class Base(DeclarativeBase):
pass


class CasbinRule(Base):
__tablename__ = "casbin_rule"

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))

def __str__(self):
arr = [self.ptype]
for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5):
if v is None:
break
arr.append(v)
return ", ".join(arr)
def create_casbin_rule_class(table_name):
"""
Factory function to create a CasbinRule class with a custom table name.

Args:
table_name (str): Table name for the CasbinRule class.

Returns:
db_class (CasbinRule): The CasbinRule class.
"""

class CasbinRule(Base):
__tablename__ = table_name
__table_args__ = {"extend_existing": True}

id = Column(Integer, primary_key=True)
ptype = Column(String(255))
v0 = Column(String(255))
v1 = Column(String(255))
v2 = Column(String(255))
v3 = Column(String(255))
v4 = Column(String(255))
v5 = Column(String(255))

def __str__(self):
arr = [self.ptype]
for v in (self.v0, self.v1, self.v2, self.v3, self.v4, self.v5):
if v is None:
break
arr.append(v)
return ", ".join(arr)

def __repr__(self):
return '<CasbinRule {}: "{}">'.format(self.id, str(self))
def __repr__(self):
return '<CasbinRule {}: "{}">'.format(self.id, str(self))

return CasbinRule


# Export the default CasbinRule class with table name 'casbin_rule'.
CasbinRule = create_casbin_rule_class("casbin_rule")


class Filter:
Expand All @@ -56,14 +74,20 @@ class Filter:
class Adapter(persist.Adapter, persist.adapters.UpdateAdapter):
"""the interface for Casbin adapters."""

def __init__(self, engine, db_class=None, filtered=False):
def __init__(
self,
engine,
db_class=None,
table_name="casbin_rule",
filtered=False,
):
if isinstance(engine, str):
self._engine = create_engine(engine)
else:
self._engine = engine

if db_class is None:
db_class = CasbinRule
db_class = create_casbin_rule_class(table_name=table_name)
else:
for attr in (
"id",
Expand Down Expand Up @@ -281,7 +305,6 @@ def _update_filtered_policies(self, new_rules, filter) -> [[str]]:
"""_update_filtered_policies updates all the policies on the basis of the filter."""

with self._session_scope() as session:

# Load old policies

query = session.query(self._db_class).filter(
Expand Down
27 changes: 26 additions & 1 deletion tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy_adapter import Adapter
from sqlalchemy_adapter import Base
from sqlalchemy_adapter import CasbinRule
from sqlalchemy_adapter.adapter import Filter
from sqlalchemy_adapter.adapter import Filter, create_casbin_rule_class


def get_fixture(path):
Expand Down Expand Up @@ -36,6 +36,25 @@ def get_enforcer():
return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)


def get_custom_table_name_enforcer():
engine = create_engine("sqlite://")
table_name = "custom_casbin_rule_table"
adapter = Adapter(engine, table_name=table_name)

session = sessionmaker(bind=engine)
Base.metadata.create_all(engine)
s = session()

CustomTableCasbinRule = create_casbin_rule_class(table_name)

s.query(CustomTableCasbinRule).delete()
s.add(CustomTableCasbinRule(ptype="p", v0="alice", v1="data1", v2="read"))
s.commit()
s.close()

return casbin.Enforcer(get_fixture("rbac_model.conf"), adapter)


class TestConfig(TestCase):
def test_custom_db_class(self):
class CustomRule(Base):
Expand All @@ -61,6 +80,12 @@ class CustomRule(Base):
s.commit()
self.assertEqual(s.query(CustomRule).all()[0].not_exist, "NotNone")

def test_custom_table_name(self):
e = get_custom_table_name_enforcer()

self.assertTrue(e.enforce("alice", "data1", "read"))
self.assertFalse(e.enforce("bob", "data2", "write"))

def test_enforcer_basic(self):
e = get_enforcer()

Expand Down