diff --git a/india_compliance/gst_india/overrides/test_transaction.py b/india_compliance/gst_india/overrides/test_transaction.py index 829ad31ff..d5e9586d2 100644 --- a/india_compliance/gst_india/overrides/test_transaction.py +++ b/india_compliance/gst_india/overrides/test_transaction.py @@ -199,10 +199,12 @@ def test_validate_item_tax_template(self): item_tax_template.save() def test_transaction_for_items_with_duplicate_taxes(self): - # Should not allow same item in invoice with multiple taxes - doc = create_transaction(**self.transaction_details, do_not_save=True) - + # Should not allow same item in invoice with multiple taxes if any of the tax row has dont_recompute_tax set to 1 + doc = create_transaction( + **self.transaction_details, do_not_save=True, is_in_state=True + ) append_item(doc, frappe._dict(item_tax_template="GST 28% - _TIRC")) + doc.taxes[0].dont_recompute_tax = 1 self.assertRaisesRegex( frappe.exceptions.ValidationError, @@ -210,6 +212,25 @@ def test_transaction_for_items_with_duplicate_taxes(self): doc.insert, ) + def test_transaction_for_items_with_different_tax_templates(self): + doc = create_transaction( + **self.transaction_details, do_not_save=True, is_in_state=True + ) + + append_item(doc, frappe._dict(item_tax_template="GST 12% - _TIRC")) + doc.insert() + + # Verify that taxes and amounts are set correctly in both items + self.assertEqual(doc.items[0].cgst_rate, 9) + self.assertEqual(doc.items[0].sgst_rate, 9) + self.assertEqual(doc.items[0].cgst_amount, 9) + self.assertEqual(doc.items[0].sgst_amount, 9) + + self.assertEqual(doc.items[1].cgst_rate, 6) + self.assertEqual(doc.items[1].sgst_rate, 6) + self.assertEqual(doc.items[1].cgst_amount, 6) + self.assertEqual(doc.items[1].sgst_amount, 6) + def test_place_of_supply_is_set(self): doc = create_transaction(**self.transaction_details) diff --git a/india_compliance/gst_india/overrides/transaction.py b/india_compliance/gst_india/overrides/transaction.py index 2238db327..cdbb30220 100644 --- a/india_compliance/gst_india/overrides/transaction.py +++ b/india_compliance/gst_india/overrides/transaction.py @@ -583,6 +583,10 @@ def validate_items(doc): if not doc.get("items"): return + # Only validate if any tax row has dont_recompute_tax enabled + if not any(row.get("dont_recompute_tax") for row in doc.taxes): + return + item_tax_templates = frappe._dict() items_with_duplicate_taxes = [] @@ -1165,6 +1169,59 @@ def get_item_defaults(self): self.item_defaults = item_defaults + def calculate_item_wise_total_tax_amount(self): + """ + Calculate total tax amount for each item for each tax type (item_name / item_code is the key) + Set tax_rate used for each item row for each tax type (item.idx is the key) + + Example: + + item_wise_total_tax_amount = { + "item_key": { + "tax_row1.name": sum(taxable_value * tax_rate1 for each row with same item_key), + "tax_row2.name": sum(taxable_value * tax_rate2 for each row with same item_key), + } + } + item_row_wise_tax_rate = { + item1.idx: {"tax_row1.name": tax_rate1, "tax_row2.name": tax_rate2}, + item2.idx: {"tax_row1.name": tax_rate1, "tax_row2.name": tax_rate2}, + } + """ + + self.item_wise_total_tax_amount = frappe._dict() + item_tax_rates = frappe._dict( + { + item.idx: frappe.parse_json(item.get("item_tax_rate")) + for item in self.doc.get("items") + } + ) + self.item_row_wise_tax_rate = frappe._dict() + + for row in self.doc.taxes: + if ( + not row.base_tax_amount_after_discount_amount + or row.gst_tax_type not in GST_TAX_TYPES + or not row.item_wise_tax_detail + ): + continue + + for item in self.doc.get("items"): + item_key = self.get_item_key(item) + item_tax_rate = item_tax_rates[item.idx] + tax_rate = item_tax_rate.get(row.account_head) or row.rate + + item_idx_tax_rate_map = self.item_row_wise_tax_rate.setdefault( + item.idx, {} + ) + item_idx_tax_rate_map[row.name] = tax_rate + + item_tax_type_map = self.item_wise_total_tax_amount.setdefault( + item_key, {} + ) + total_tax_amount = item_tax_type_map.get(row.name) or 0 + total_tax_amount += item.taxable_value * (tax_rate / 100) + item_tax_type_map[row.name] = total_tax_amount + def set_item_wise_tax_details(self): """ Item Tax Details complied @@ -1185,11 +1242,13 @@ def set_item_wise_tax_details(self): - There could be more than one row for same account - Item count added to handle rounding errors """ + self.calculate_item_wise_total_tax_amount() tax_details = frappe._dict() + self.item_wise_tax_details = frappe._dict() for row in self.doc.get("items"): - key = row.item_code or row.item_name + key = self.get_item_key(row) if key not in tax_details: tax_details[key] = self.item_defaults.copy() @@ -1208,6 +1267,7 @@ def set_item_wise_tax_details(self): tax_amount_field = f"{tax}_amount" old = json.loads(row.item_wise_tax_detail) + self.item_wise_tax_details[row.name] = old tax_difference = row.base_tax_amount_after_discount_amount last_item_with_tax = None @@ -1242,9 +1302,32 @@ def set_item_wise_tax_details(self): # Handle rounding errors if tax_difference and last_item_with_tax: last_item_with_tax[tax_amount_field] += tax_difference - self.item_tax_details = tax_details + def get_item_row_tax_rate(self, item, tax_row, default_rate): + item_idx_tax_rate_map = self.item_row_wise_tax_rate.get(item.idx) + + if not item_idx_tax_rate_map: + return default_rate + + tax_rate = item_idx_tax_rate_map.get(tax_row.name) + + return tax_rate if tax_rate is not None else default_rate + + def get_item_row_tax_amount_factor(self, item, tax_row): + key = self.get_item_key(item) + item_tax_type_map = self.item_wise_total_tax_amount.get(key) or {} + total_tax_amount = item_tax_type_map.get(tax_row.name) + + if not total_tax_amount: + return 1 + + item_wise_tax_detail = self.item_wise_tax_details.get(tax_row.name) or {} + total_tax_amount_used = item_wise_tax_detail.get(key).get("tax_amount", 0) + tax_amount_factor = total_tax_amount_used / total_tax_amount + + return tax_amount_factor or 1 + def update_item_tax_details(self): for item in self.doc.get("items"): item.update(self.get_item_tax_detail(item)) @@ -1277,8 +1360,17 @@ def get_item_tax_detail(self, item): # Handle rounding errors response = item_tax_detail.copy() - for tax in GST_TAX_TYPES: - if (tax_rate := item_tax_detail[f"{tax}_rate"]) == 0: + for row in self.doc.taxes: + if row.gst_tax_type not in GST_TAX_TYPES: + continue + + tax = row.gst_tax_type + tax_rate_field = f"{tax}_rate" + tax_rate = self.get_item_row_tax_rate( + item, row, item_tax_detail[tax_rate_field] + ) + + if tax == 0: continue tax_amount_field = f"{tax}_amount" @@ -1287,11 +1379,17 @@ def get_item_tax_detail(self, item): multiplier = ( item.qty if tax == "cess_non_advol" else item.taxable_value / 100 ) - tax_amount = flt(tax_rate * multiplier, precision) + tax_amount_factor = ( + 1 + if tax == "cess_non_advol" + else self.get_item_row_tax_amount_factor(item, row) + ) + + tax_amount = flt(tax_amount_factor * tax_rate * multiplier, precision) item_tax_detail[tax_amount_field] -= tax_amount - response.update({tax_amount_field: tax_amount}) + response.update({tax_amount_field: tax_amount, tax_rate_field: tax_rate}) return response